Example 1
    def _policy(
        self,
        agent: str,
        observation: types.NestedTensor,
    ) -> Tuple[types.NestedTensor, types.NestedTensor]:
        """Agent-specific policy function.

        Args:
            agent (str): agent id
            observation (types.NestedTensor): observation tensor received from the
                environment.

        Raises:
            NotImplementedError: unknown action space

        Returns:
            Tuple[types.NestedTensor, types.NestedTensor]: agent action and the
                policy output it was derived from.
        """

        # Add a dummy batch dimension and as a side effect convert numpy to TF.
        batched_observation = tf2_utils.add_batch_dim(observation)

        # index network either on agent type or on agent id
        agent_key = agent.split("_")[0] if self._shared_weights else agent

        # Compute the policy, conditioned on the observation.
        policy = self._policy_networks[agent_key](batched_observation)

        # TODO (dries): Make this support hybrid action spaces.
        if isinstance(self._agent_specs[agent].actions, BoundedArray):
            # Continuous action space: use the policy output directly.
            action = policy
        elif isinstance(self._agent_specs[agent].actions, DiscreteArray):
            # Discrete action space: take the greedy action.
            action = tf.math.argmax(policy, axis=1)
        else:
            raise NotImplementedError("Unknown action space.")

        return action, policy
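
The snippet above leans on a round trip that recurs throughout these examples: `tf2_utils.add_batch_dim` prepends a dummy batch dimension of size 1 (converting numpy inputs to TF tensors as a side effect), and `tf2_utils.to_numpy_squeeze` removes it again on the way out. A minimal sketch of that round trip, assuming dm-acme's TF utilities are available under the alias used above:

    import numpy as np
    from acme.tf import utils as tf2_utils

    observation = np.zeros((10,), dtype=np.float32)  # unbatched observation
    batched = tf2_utils.add_batch_dim(observation)   # tf.Tensor of shape (1, 10)
    # ... feed `batched` through a policy network here ...
    unbatched = tf2_utils.to_numpy_squeeze(batched)  # np.ndarray of shape (10,)
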
Example 2
    def select_action(self,
                      observation: types.NestedArray) -> types.NestedArray:
        # Add a dummy batch dimension and as a side effect convert numpy to TF.
        batched_obs = tf2_utils.add_batch_dim(observation)

        # Initialize the RNN state if necessary.
        if self._state is None:
            self._state = self._network.initial_state(1)

        # Forward.
        policy_output, new_state = self._policy(batched_obs, self._state)

        # If the policy network parameterises a distribution, sample from it.
        def maybe_sample(output):
            if isinstance(output, tfd.Distribution):
                output = output.sample()
            return output

        policy_output = tree.map_structure(maybe_sample, policy_output)

        self._prev_state = self._state
        self._state = new_state

        # Convert to numpy and squeeze out the batch dimension.
        action = tf2_utils.to_numpy_squeeze(policy_output)

        return action
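
The recurrent variant above carries the RNN state across calls and initializes it lazily. A common companion convention, assumed here rather than shown in the snippet, is to clear the state at episode boundaries so the next `select_action` re-initializes it. A minimal illustrative sketch with made-up class and method names:

    import sonnet as snt
    from acme.tf import utils as tf2_utils

    class MinimalRecurrentActor:  # illustrative only, not the library's actor
        def __init__(self, network: snt.RNNCore):
            self._network = network
            self._state = None

        def observe_first(self, timestep):
            # Drop the state so that the next select_action re-initializes it.
            self._state = None

        def select_action(self, observation):
            batched_obs = tf2_utils.add_batch_dim(observation)
            if self._state is None:
                self._state = self._network.initial_state(1)
            output, self._state = self._network(batched_obs, self._state)
            return tf2_utils.to_numpy_squeeze(output)
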
Example 3
  def test_snapshot_distribution(self):
    """Test that snapshotter correctly calls saves/restores snapshots."""
    # Create a test network.
    net1 = snt.Sequential([
        networks.LayerNormMLP([10, 10]),
        networks.MultivariateNormalDiagHead(1)
    ])
    spec = specs.Array([10], dtype=np.float32)
    tf2_utils.create_variables(net1, [spec])

    # Save the test network.
    directory = self.get_tempdir()
    objects_to_save = {'net': net1}
    snapshotter = tf2_savers.Snapshotter(objects_to_save, directory=directory)
    snapshotter.save()

    # Reload the test network.
    net2 = tf.saved_model.load(os.path.join(snapshotter.directory, 'net'))
    inputs = tf2_utils.add_batch_dim(tf2_utils.zeros_like(spec))

    with tf.GradientTape() as tape:
      dist1 = net1(inputs)
      loss1 = tf.math.reduce_sum(dist1.mean() + dist1.variance())
      grads1 = tape.gradient(loss1, net1.trainable_variables)

    with tf.GradientTape() as tape:
      dist2 = net2(inputs)
      loss2 = tf.math.reduce_sum(dist2.mean() + dist2.variance())
      grads2 = tape.gradient(loss2, net2.trainable_variables)

    assert all(tree.map_structure(np.allclose, list(grads1), list(grads2)))
Example 4
    def _policy(
        self,
        agent: str,
        observation: types.NestedTensor,
    ) -> Tuple[types.NestedTensor, types.NestedTensor]:
        """Agent specific policy function

        Args:
            agent (str): agent id
            observation (types.NestedTensor): observation tensor received from the
                environment.

        Returns:
            Tuple[types.NestedTensor, types.NestedTensor]: log probabilities and action
        """

        # Index network either on agent type or on agent id.
        network_key = agent.split("_")[0] if self._shared_weights else agent

        # Add a dummy batch dimension and as a side effect convert numpy to TF.
        observation = tf2_utils.add_batch_dim(observation.observation)

        # Compute the policy, conditioned on the observation.
        policy = self._policy_networks[network_key](observation)

        # Sample from the policy and compute the log likelihood.
        action = policy.sample()
        log_prob = policy.log_prob(action)

        # Cast for compatibility with reverb:
        # sample() returns an int32, which does not match the expected int64 dtype.
        if isinstance(policy, tfp.distributions.Categorical):
            action = tf.cast(action, "int64")

        return log_prob, action
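
The final cast reflects a general TensorFlow Probability behaviour: `Categorical.sample()` defaults to int32, so it must be cast when downstream code expects int64. A small self-contained sketch, independent of the agent code above:

    import tensorflow as tf
    import tensorflow_probability as tfp
    tfd = tfp.distributions

    policy = tfd.Categorical(logits=tf.zeros([1, 4]))
    action = policy.sample()              # dtype int32 by default
    log_prob = policy.log_prob(action)
    action = tf.cast(action, "int64")     # e.g. to match an int64 action spec
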
Example 5
  def test_rnn_snapshot(self):
    """Test that snapshotter correctly calls saves/restores snapshots on RNNs."""
    # Create a test network.
    net = snt.LSTM(10)
    spec = specs.Array([10], dtype=np.float32)
    tf2_utils.create_variables(net, [spec])

    # Test that if you add some postprocessing without rerunning
    # create_variables, it still works.
    wrapped_net = snt.DeepRNN([net, lambda x: x])

    for net1 in [net, wrapped_net]:
      # Save the test network.
      directory = self.get_tempdir()
      objects_to_save = {'net': net1}
      snapshotter = tf2_savers.Snapshotter(objects_to_save, directory=directory)
      snapshotter.save()

      # Reload the test network.
      net2 = tf.saved_model.load(os.path.join(snapshotter.directory, 'net'))
      inputs = tf2_utils.add_batch_dim(tf2_utils.zeros_like(spec))

      with tf.GradientTape() as tape:
        outputs1, next_state1 = net1(inputs, net1.initial_state(1))
        loss1 = tf.math.reduce_sum(outputs1)
        grads1 = tape.gradient(loss1, net1.trainable_variables)

      with tf.GradientTape() as tape:
        outputs2, next_state2 = net2(inputs, net2.initial_state(1))
        loss2 = tf.math.reduce_sum(outputs2)
        grads2 = tape.gradient(loss2, net2.trainable_variables)

      assert np.allclose(outputs1, outputs2)
      assert np.allclose(tree.flatten(next_state1), tree.flatten(next_state2))
      assert all(tree.map_structure(np.allclose, list(grads1), list(grads2)))
Example 6
    def select_action(self,
                      rl_obs,
                      unused_norm_base_act,
                      prev_residual=None,
                      prev_exploration=False,
                      add_exploration=True,
                      collapse=False,
                      verbose=False):
        mean, std = None, None
        chose_exploration = False
        if collapse and self.rl_eval_policy is not None:
            # Call custom deterministic policy; overrides gripper exploration.
            batched_rl_obs = tf_utils.add_batch_dim(rl_obs)
            batched_residual = self.rl_eval_policy(batched_rl_obs)
            residual = batched_residual.numpy().squeeze(axis=0)
        else:
            residual = self.rl_agent.select_action(rl_obs)
            if add_exploration:
                if prev_exploration and np.random.rand() < self.sticky_rate:
                    residual[0] = prev_residual[0]
                    chose_exploration = True
                elif np.random.rand() < self.bernoulli_rate:
                    # TODO(minttu): Only explore open if gripper not fully opened (& vice
                    # versa).
                    residual[0] = (np.random.rand() < 0.5) * 4 - 2
                    chose_exploration = True
            if self._gaussian_residual:
                batched_rl_obs = tf_utils.add_batch_dim(rl_obs)
                mean, std = self.rl_policy_params(batched_rl_obs)
                if collapse:  # Collapse overrides gripper exploration.
                    if verbose:
                        print(
                            f'Collapsing {residual} to mean {mean} (std {std})'
                        )
                    residual = mean
                elif chose_exploration and verbose:
                    print(f'Exploring {residual} from mean {mean} (std {std})')
                elif verbose:
                    print(f'Drew {residual} from mean {mean} (std {std})')

        action = self.denormalize_flat(residual)
        residual_action = self.denormalize_flat(residual)
        base_action = np.zeros_like(residual)
        return (action, base_action, residual_action, residual,
                chose_exploration, mean, std)
Example 7
    def _policy(
        self,
        agent: str,
        observation: types.NestedTensor,
        legal_actions: types.NestedTensor,
        epsilon: tf.Tensor,
        fingerprint: Optional[tf.Tensor] = None,
    ) -> types.NestedTensor:
        """Agent specific policy function

        Args:
            agent (str): agent id
            observation (types.NestedTensor): observation tensor received from the
                environment.
            legal_actions (types.NestedTensor): actions allowed to be taken at the
                current observation.
            epsilon (tf.Tensor): value for epsilon greedy action selection.
            fingerprint (Optional[tf.Tensor], optional): policy fingerprints. Defaults
                to None.

        Returns:
            types.NestedTensor: agent action
        """

        # Add a dummy batch dimension and as a side effect convert numpy to TF.
        batched_observation = tf2_utils.add_batch_dim(observation)
        batched_legals = tf2_utils.add_batch_dim(legal_actions)

        # index network either on agent type or on agent id
        agent_key = agent.split("_")[0] if self._shared_weights else agent

        # Compute the policy, conditioned on the observation and
        # possibly the fingerprint.
        if fingerprint is not None:
            q_values = self._q_networks[agent_key](batched_observation,
                                                   fingerprint)
        else:
            q_values = self._q_networks[agent_key](batched_observation)

        # select legal action
        action = self._action_selectors[agent_key](q_values,
                                                   batched_legals,
                                                   epsilon=epsilon)

        return action
Example 8
    def _policy(
        self,
        agent: str,
        observation: types.NestedTensor,
        state: types.NestedTensor,
        message: types.NestedTensor,
        legal_actions: types.NestedTensor,
        epsilon: tf.Tensor,
    ) -> types.NestedTensor:
        """Agent specific policy function

        Args:
            agent (str): agent id
            observation (types.NestedTensor): observation tensor received from the
                environment.
            state (types.NestedTensor): Recurrent network state.
            message (types.NestedTensor): received agent message.
            legal_actions (types.NestedTensor): actions allowed to be taken at the
                current observation.
            epsilon (tf.Tensor): value for epsilon greedy action selection.

        Returns:
            types.NestedTensor: action and new recurrent hidden state
        """

        # Add a dummy batch dimension and as a side effect convert numpy to TF.
        batched_observation = tf2_utils.add_batch_dim(observation)
        batched_legals = tf2_utils.add_batch_dim(legal_actions)

        # index network either on agent type or on agent id
        agent_key = agent.split("_")[0] if self._shared_weights else agent

        # Compute the policy, conditioned on the observation.
        (q_values, m_values), new_state = self._q_networks[agent_key](
            batched_observation, state, message)

        # select legal action
        action = self._action_selectors[agent_key](q_values,
                                                   batched_legals,
                                                   epsilon=epsilon)

        return (action, m_values), new_state
Example 9
  def _policy(self, observation: types.NestedTensor) -> types.NestedTensor:
    # Add a dummy batch dimension and as a side effect convert numpy to TF.
    batched_observation = tf2_utils.add_batch_dim(observation)

    # Compute the policy, conditioned on the observation.
    policy = self._policy_network(batched_observation)

    # Sample from the policy if it is stochastic.
    action = policy.sample() if isinstance(policy, tfd.Distribution) else policy

    return action
Example 10
    def rl_policy_params(self, observation):
        if self.rl_observation_network_type is None:
            batched_observation = tf_utils.add_batch_dim(observation)
        else:
            batched_observation = (
                self.rl_agent._learner._observation_network(observation))  # pylint: disable=protected-access
        policy_distr = self.rl_agent._learner._policy_network(
            batched_observation)  # pylint: disable=protected-access
        mean = policy_distr.loc.numpy().squeeze(axis=0)
        std = policy_distr.scale.diag.numpy().squeeze(axis=0)
        return mean, std
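
`rl_policy_params` assumes the policy head produces a diagonal Gaussian whose mean and per-dimension scale are exposed via `loc` and `scale.diag`. A minimal sketch of that access pattern, assuming a `tfp.distributions.MultivariateNormalDiag` output:

    import tensorflow as tf
    import tensorflow_probability as tfp
    tfd = tfp.distributions

    dist = tfd.MultivariateNormalDiag(loc=tf.zeros([1, 3]),
                                      scale_diag=tf.ones([1, 3]))
    mean = dist.loc.numpy().squeeze(axis=0)        # shape (3,)
    std = dist.scale.diag.numpy().squeeze(axis=0)  # shape (3,)
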
Example 11
    def select_action(self,
                      observation: types.NestedArray) -> types.NestedArray:
        # Add a dummy batch dimension and as a side effect convert numpy to TF.
        batched_obs = tf2_utils.add_batch_dim(observation)

        # Forward the policy network.
        action = self._policy(batched_obs)

        # Convert to numpy and squeeze out the batch dimension.
        action = tf2_utils.to_numpy_squeeze(action)

        return action
Example 12
    def _policy(self, observation: types.NestedTensor,
                mask: types.NestedTensor) -> types.NestedTensor:
        # Add a dummy batch dimension and as a side effect convert numpy to TF.
        batched_observation = tf2_utils.add_batch_dim(observation)

        # Compute the policy, conditioned on the observation.
        qs = self._policy_network(batched_observation)

        qs = qs * tf.cast(mask, dtype=tf.float32)
        # Epsilon-greedy action selection over the masked Q-values.
        action = trfl.epsilon_greedy(qs, epsilon=0.05).sample()

        return action
Example 13
    def select_action(self,
                      observation: types.NestedArray) -> types.NestedArray:
        # Add a dummy batch dimension and as a side effect convert numpy to TF.
        batched_observation = tf2_utils.add_batch_dim(observation)

        # Compute the policy, conditioned on the observation.
        policy = self._policy_network(batched_observation)
        if self._deterministic_policy:
            action = policy.mean()
        else:
            action = policy.sample()
        self._log_prob = policy.log_prob(action)
        return tf2_utils.to_numpy_squeeze(action)
Example 14
    def select_action(self,
                      observation: types.NestedArray) -> types.NestedArray:
        # Add a dummy batch dimension and as a side effect convert numpy to TF.
        batched_observation = tf2_utils.add_batch_dim(observation)

        # Compute the policy, conditioned on the observation.
        action, policy, log_prob = self._policy_network.getAll(
            batched_observation)

        self._prev_logP = log_prob
        self._prev_means = policy

        # Return a numpy array with squeezed out batch dimension.
        return tf2_utils.to_numpy_squeeze(action)
Example 15
def create_variables(
    network: snt.Module,
    input_spec: List[OLT],
) -> Optional[tf.TensorSpec]:
    """Builds the network with dummy inputs to create the necessary variables.
    Args:
      network: Sonnet Module whose variables are to be created.
      input_spec: list of input specs to the network. The length of this list
        should match the number of arguments expected by `network`.
    Returns:
      output_spec: only returns an output spec if the output is a tf.Tensor, else
          it doesn't return anything (None); e.g. if the output is a
          tfp.distributions.Distribution.
    """
    # Create a dummy observation with no batch dimension.
    dummy_input = [
        OLT(
            observation=zeros_like(in_spec.observation),
            legal_actions=ones_like(in_spec.legal_actions),
            terminal=zeros_like(in_spec.terminal),
        ) for in_spec in input_spec
    ]

    # If we have an RNNCore the hidden state will be an additional input.
    if isinstance(network, snt.RNNCore):
        initial_state = squeeze_batch_dim(network.initial_state(1))
        dummy_input += [initial_state]

    # Forward pass of the network which will create variables as a side effect.
    dummy_output = network(*add_batch_dim(dummy_input))

    # Evaluate the input signature by converting the dummy input into a
    # TensorSpec. We then save the signature as a property of the network. This is
    # done so that we can later use it when creating snapshots. We do this here
    # because the snapshot code may not have access to the precise form of the
    # inputs.
    input_signature = tree.map_structure(
        lambda t: tf.TensorSpec((None, ) + t.shape, t.dtype), dummy_input)
    network._input_signature = input_signature  # pylint: disable=protected-access

    def spec(output: tf.Tensor) -> tf.TensorSpec:
        # If the output is not a Tensor, return None as spec is ill-defined.
        if not isinstance(output, tf.Tensor):
            return None
        # If this is not a scalar Tensor, make sure to squeeze out the batch dim.
        if tf.rank(output) > 0:
            output = squeeze_batch_dim(output)
        return tf.TensorSpec(output.shape, output.dtype)

    return tree.map_structure(spec, dummy_output)
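
For reference, a hypothetical call site for this `create_variables` variant. The `OLT` spec shapes and the toy network below are assumptions for illustration, not taken from the surrounding code:

    import numpy as np
    import sonnet as snt
    from acme import specs

    # Hypothetical spec: 10-dim observation, 5 legal-action flags, scalar terminal.
    olt_spec = OLT(
        observation=specs.Array((10,), np.float32),
        legal_actions=specs.Array((5,), np.float32),
        terminal=specs.Array((1,), np.float32),
    )

    class ToyOLTNetwork(snt.Module):
        """Toy network that reads the observation field of an OLT."""

        def __init__(self):
            super().__init__()
            self._mlp = snt.nets.MLP([64, 5])

        def __call__(self, olt):
            return self._mlp(olt.observation)

    output_spec = create_variables(ToyOLTNetwork(), [olt_spec])
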
Example 16
    def _policy(
        self, observation: types.NestedTensor, state: types.NestedTensor,
        mask: types.NestedTensor
    ) -> Tuple[types.NestedTensor, types.NestedTensor]:

        # Add a dummy batch dimension and as a side effect convert numpy to TF.
        batched_observation = tf2_utils.add_batch_dim(observation)

        # Compute the policy, conditioned on the observation.
        qvals, new_state = self._network(batched_observation, state)

        # Epsilon-greedy action selection restricted to the legal actions.
        action = trfl.epsilon_greedy(qvals,
                                     epsilon=0.05,
                                     legal_actions_mask=tf.cast(
                                         mask, dtype=tf.float32)).sample()
        return action, new_state
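
`trfl.epsilon_greedy` returns a distribution whose probabilities mix the greedy choice with uniform exploration (weight epsilon), restricted to the legal actions when a mask is given; sampling it yields action indices. A standalone sketch with toy Q-values:

    import tensorflow as tf
    import trfl

    q_values = tf.constant([[1.0, 3.0, 2.0, 0.5]])    # batch of one
    legal_mask = tf.constant([[1.0, 0.0, 1.0, 1.0]])  # action 1 is illegal
    dist = trfl.epsilon_greedy(q_values, epsilon=0.05,
                               legal_actions_mask=legal_mask)
    action = dist.sample()                            # int tensor of shape (1,)
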
Example 17
  def select_action(self, observation: types.NestedArray) -> types.NestedArray:
    # Add a dummy batch dimension and as a side effect convert numpy to TF.
    batched_obs = tf2_utils.add_batch_dim(observation)

    if self._state is None:
      self._state = self._network.initial_state(1)

    # Forward.
    (logits, _), new_state = self._policy(batched_obs, self._state)

    self._prev_logits = logits
    self._prev_state = self._state
    self._state = new_state

    action = tfd.Categorical(logits).sample()
    action = tf2_utils.to_numpy_squeeze(action)

    return action
Example 18
    def select_action(self,
                      observation: types.NestedArray) -> types.NestedArray:
        # Add a dummy batch dimension and as a side effect convert numpy to TF.
        batched_obs = tf2_utils.add_batch_dim(observation)

        # Initialize the RNN state if necessary.
        if self._state is None:
            self._state = self._network.initial_state(1)

        # Forward.
        policy_output, new_state = self._policy(batched_obs, self._state)

        self._prev_state = self._state
        self._state = new_state

        # Convert to numpy and squeeze out the batch dimension.
        action = tf2_utils.to_numpy_squeeze(policy_output)

        return action
Example 19
  def step(self, action: types.Action):
    # The model must have been reset with an initial timestep before stepping.
    if self._needs_reset:
      raise ValueError('Model must be reset with an initial timestep.')

    # Step the model.
    state, action = tf2_utils.add_batch_dim([self._state, action])
    new_state, reward, discount_logits = [
        x.numpy().squeeze(axis=0) for x in self._forward(state, action)
    ]
    discount = special.softmax(discount_logits)

    # Save the resulting state for the next step.
    self._state = new_state

    # We threshold discount on a given tolerance.
    if discount < self._terminal_tol:
      self._needs_reset = True
      return dm_env.termination(reward=reward, observation=self._state.copy())
    return dm_env.transition(reward=reward, observation=self._state.copy())
Example 20
    def select_action(self,
                      observation: types.NestedArray) -> types.NestedArray:
        # Add a dummy batch dimension and as a side effect convert numpy to TF.
        batched_obs = tf2_utils.add_batch_dim(observation)

        # Forward the policy network.
        policy_output = self._policy_network(batched_obs)

        # If the policy network parameterises a distribution, sample from it.
        def maybe_sample(output):
            if isinstance(output, tfd.Distribution):
                output = output.sample()
            return output

        policy_output = tree.map_structure(maybe_sample, policy_output)

        # Convert to numpy and squeeze out the batch dimension.
        action = tf2_utils.to_numpy_squeeze(policy_output)

        return action
Example 21
def cal_mse(value_func, policy_net, environment, mse_samples, discount):
    sample_count = 0
    actor = actors.FeedForwardActor(policy_network=policy_net)
    timestep = environment.reset()
    actor.observe_first(timestep)
    mse = 0.0
    while sample_count < mse_samples:
        current_obs = timestep.observation
        action = actor.select_action(current_obs)
        timestep = environment.step(action)
        actor.observe(action, next_timestep=timestep)
        next_obs = timestep.observation
        reward = timestep.reward

        if timestep.last():
            timestep = environment.reset()
            actor.observe_first(timestep)
            current_obs = tf2_utils.add_batch_dim(current_obs)
            action = tf2_utils.add_batch_dim(action)
            mse_one = (reward - value_func(current_obs, action))**2
            print(value_func(current_obs, action).numpy().squeeze())
            print(f'reward = {reward}')
            print('=====End Episode=====')

        else:
            next_action = tf2_utils.add_batch_dim(
                actor.select_action(next_obs))
            action = tf2_utils.add_batch_dim(action)
            current_obs = tf2_utils.add_batch_dim(current_obs)
            next_obs = tf2_utils.add_batch_dim(next_obs)
            mse_one = (reward + discount * value_func(next_obs, next_action) -
                       value_func(current_obs, action))**2
            print(value_func(current_obs, action).numpy().squeeze())
        mse = mse + mse_one.numpy()
        sample_count += 1
    return mse.squeeze() / mse_samples
Example 22
    def select_action(self,
                      rl_obs,
                      norm_base_act,
                      full_obs=None,
                      prev_residual=None,
                      prev_exploration=False,
                      add_exploration=True,
                      collapse=False,
                      verbose=False):
        mean, std = None, None
        chose_exploration = False
        if collapse and self.rl_eval_policy is not None:
            # Call custom deterministic policy; overrides gripper exploration.
            batched_rl_obs = tf_utils.add_batch_dim(rl_obs)
            batched_residual = self.rl_eval_policy(batched_rl_obs)
            residual = batched_residual.numpy().squeeze(axis=0)
        else:
            residual = self.rl_agent.select_action(rl_obs)
            if add_exploration:
                if prev_exploration and np.random.rand() < self.sticky_rate:
                    residual[0] = prev_residual[0]
                    chose_exploration = True
                elif np.random.rand() < self.bernoulli_rate:
                    # Only explore open if gripper not currently opened (& vice versa).
                    grip_state = full_obs['grip_state']
                    if grip_state < 0.4:
                        residual[0] = -2  # Close
                    else:
                        residual[0] = 2  # Open
                    # residual[0] = (np.random.rand() < 0.5) * 4 - 2
                    chose_exploration = True
            if self._gaussian_residual:
                mean, std = self.rl_policy_params(rl_obs)
                if collapse:  # Collapse overrides gripper exploration.
                    if verbose:
                        print(
                            f'Collapsing {residual} to mean {mean} (std {std})'
                        )
                    residual = mean
                elif chose_exploration and verbose:
                    print(f'Exploring {residual} from mean {mean} (std {std})')
                elif verbose:
                    print(f'Drew {residual} from mean {mean} (std {std})')

        base_action = self.denormalize_flat_base_action(norm_base_act)
        residual_action = self.denormalize_flat(residual)
        if self.action_space.norm_type != self.base_agent.action_space.norm_type:
            # Normalize each action separately. denorm(r) + denorm(b).
            # Makes sense with residual_norm == centered
            if isinstance(base_action, dict):
                action = {
                    k: v + residual_action[k]
                    for k, v in base_action.items()
                }
            else:
                action = base_action + residual_action
        else:
            # Normalize once only. denorm(r + b).
            # Makes sense with base_norm == residual_norm == zeromean_unitvar.
            norm_act = norm_base_act + residual
            action = self.denormalize_flat(norm_act)
        return (action, base_action, residual_action, residual,
                chose_exploration, mean, std)
Example 23
def _generate_data(
    policy_net,
    environment,
    n_samples,
    batch_size,
    shuffle,
    include_terminal=False,  # Include terminal absorbing state.
    ignore_d_tm1=False  # Set d_tm1 as constant 1.0 if True.
):
    sample_count = 0
    actor = actors.FeedForwardActor(policy_network=policy_net)
    timestep = environment.reset()
    actor.observe_first(timestep)

    current_obs_list = []
    action_list = []
    next_obs_list = []
    reward_list = []
    discount_list = []
    nonterminal_list = []
    while sample_count < n_samples:
        current_obs = timestep.observation
        action = actor.select_action(current_obs)
        timestep = environment.step(action)
        actor.observe(action, next_timestep=timestep)
        next_obs = timestep.observation
        reward = timestep.reward
        discount = np.array(1.0, dtype=np.float32)
        if timestep.last() and not include_terminal:
            discount = np.array(0.0, dtype=np.float32)

        current_obs_list.append(tf2_utils.add_batch_dim(current_obs))
        action_list.append(tf2_utils.add_batch_dim(action))
        reward_list.append(tf2_utils.add_batch_dim(reward))
        discount_list.append(tf2_utils.add_batch_dim(discount))
        next_obs_list.append(tf2_utils.add_batch_dim(next_obs))
        nonterminal_list.append(
            tf2_utils.add_batch_dim(np.array(1.0, dtype=np.float32)))

        if timestep.last():
            if include_terminal:
                # Make another transition tuple from s, a -> s, a with 0 reward.
                current_obs = next_obs
                # action = actor.select_action(current_obs)
                reward = np.zeros_like(timestep.reward)
                discount = np.array(1.0, dtype=np.float32)
                next_obs = current_obs

                if ignore_d_tm1:
                    d_tm1 = np.array(1.0, dtype=np.float32)
                else:
                    d_tm1 = np.array(0.0, dtype=np.float32)

                for i in range(environment.action_spec().num_values):
                    action_ = np.array(i, dtype=action.dtype).reshape(
                        action.shape)

                    current_obs_list.append(
                        tf2_utils.add_batch_dim(current_obs))
                    action_list.append(tf2_utils.add_batch_dim(action_))
                    reward_list.append(tf2_utils.add_batch_dim(reward))
                    discount_list.append(tf2_utils.add_batch_dim(discount))
                    next_obs_list.append(tf2_utils.add_batch_dim(next_obs))
                    nonterminal_list.append(tf2_utils.add_batch_dim(d_tm1))

            timestep = environment.reset()
            actor.observe_first(timestep)

        sample_count += 1

    current_obs_data = tf.concat(current_obs_list, axis=0)
    action_data = tf.concat(action_list, axis=0)
    next_obs_data = tf.concat(next_obs_list, axis=0)
    reward_data = tf.concat(reward_list, axis=0)
    discount_data = tf.concat(discount_list, axis=0)
    nonterminal_data = tf.concat(nonterminal_list, axis=0)

    dataset = tf.data.Dataset.from_tensor_slices((
        current_obs_data,
        action_data,
        reward_data,
        discount_data,
        next_obs_data,
        # The last action is not valid
        # and should not be used.
        action_data,
        nonterminal_data))

    def _reverb_sample(*data_tuple):
        info = reverb.SampleInfo(key=tf.constant(0, tf.uint64),
                                 probability=tf.constant(1.0, tf.float64),
                                 table_size=tf.constant(0, tf.int64),
                                 priority=tf.constant(1.0, tf.float64))
        return reverb.ReplaySample(info=info, data=data_tuple)

    dataset = dataset.map(_reverb_sample,
                          num_parallel_calls=tf.data.experimental.AUTOTUNE)

    dataset = dataset.cache()
    if shuffle:
        dataset = dataset.shuffle(batch_size * 10)
    dataset = dataset.repeat()
    dataset = dataset.batch(batch_size, drop_remainder=True)
    return dataset
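
A brief consumption sketch for the dataset built above; the argument values are placeholders, `policy_net` and `environment` are assumed to exist, and the tuple names below are interpretive:

    # Each element is a reverb.ReplaySample whose data tuple follows the order
    # used in _generate_data above.
    dataset = _generate_data(policy_net, environment, n_samples=1000,
                             batch_size=32, shuffle=True)
    for sample in dataset.take(1):
        o_tm1, a_tm1, r_t, d_t, o_t, a_t, nonterminal = sample.data
        print(o_tm1.shape, a_tm1.shape, r_t.shape)
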