def _generate_replay_buffer(self, rb_cls):
    stack_count = 4
    shape = (15, 15, stack_count)
    single_shape = (15, 15, 1)
    observation_spec = array_spec.ArraySpec(shape, np.int32, 'obs')
    time_step_spec = ts.time_step_spec(observation_spec)
    action_spec = policy_step.PolicyStep(array_spec.BoundedArraySpec(
        shape=(), dtype=np.int32, minimum=0, maximum=1, name='action'))
    self._trajectory_spec = trajectory.from_transition(
        time_step_spec, action_spec, time_step_spec)

    self._capacity = 32
    self._replay_buffer = rb_cls(
        data_spec=self._trajectory_spec, capacity=self._capacity)

    # Generate N frames: the value of pixels is the frame index.
    # The observations will be generated by stacking K frames out of those N,
    # generating some redundancies between the observations.
    single_frames = []
    frame_count = 100
    for k in range(frame_count):
      single_frames.append(np.full(single_shape, k, dtype=np.int32))

    # Add stack of frames to the replay buffer.
    time_steps = []
    for k in range(len(single_frames) - stack_count + 1):
      observation = np.concatenate(single_frames[k:k + stack_count], axis=-1)
      time_steps.append(ts.transition(observation, reward=0.0))

    self._transition_count = len(time_steps) - 1
    dummy_action = policy_step.PolicyStep(np.int32(0))
    for k in range(self._transition_count):
      self._replay_buffer.add_batch(nest_utils.batch_nested_array(
          trajectory.from_transition(
              time_steps[k], dummy_action, time_steps[k + 1])))
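The comment above describes the redundancy between stacked observations; a minimal NumPy-only sketch (not part of the snippet, names chosen for illustration) of what that overlap looks like:

import numpy as np

stack_count = 4
single_shape = (15, 15, 1)
frames = [np.full(single_shape, k, dtype=np.int32) for k in range(6)]

# Two consecutive stacked observations share stack_count - 1 single frames,
# which is the redundancy a de-duplicating (e.g. hashed) replay buffer can exploit.
obs_0 = np.concatenate(frames[0:stack_count], axis=-1)      # frames 0..3
obs_1 = np.concatenate(frames[1:1 + stack_count], axis=-1)  # frames 1..4
assert np.array_equal(obs_0[..., 1:], obs_1[..., :-1])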
Example #2
    def _collect_step(self, time_step, metric_observers, train=False):
        """Run a single step (or 2 steps on life loss) in the environment."""
        if train:
            policy = self._collect_policy
        else:
            policy = self._eval_policy

        with self._action_timer:
            action_step = policy.action(time_step)
        with self._step_timer:
            next_time_step = self._env.step(action_step.action)
            traj = trajectory.from_transition(time_step, action_step,
                                              next_time_step)

        if next_time_step.is_last() and not self.game_over():
            traj = traj._replace(discount=np.array([1.0], dtype=np.float32))

        if train:
            self._store_to_rb(traj)

        # When AtariPreprocessing.terminal_on_life_loss is True, we receive LAST
        # time_steps when lives are lost but the game is not over. In this mode, the
        # replay buffer and agent's policy must see the life loss as a LAST step
        # and the subsequent step as a FIRST step. However, we do not want to
        # actually terminate the episode and metrics should be computed as if all
        # steps were MID steps, since life loss is not actually a terminal event
        # (it is mostly a trick to make it easier to propagate rewards backwards by
        # shortening episode durations from the agent's perspective).
        if next_time_step.is_last() and not self.game_over():
            # Update metrics as if this is a mid-episode step.
            next_time_step = ts.transition(next_time_step.observation,
                                           next_time_step.reward)
            self._observe(
                metric_observers,
                trajectory.from_transition(time_step, action_step,
                                           next_time_step))

            # Produce the next step as if this is the first step of an episode and
            # store to RB as such. The next_time_step will be a MID time step.
            reward = time_step.reward
            time_step = ts.restart(next_time_step.observation)
            with self._action_timer:
                action_step = policy.action(time_step)
            with self._step_timer:
                next_time_step = self._env.step(action_step.action)
            if train:
                self._store_to_rb(
                    trajectory.from_transition(time_step, action_step,
                                               next_time_step))

            # Update metrics as if this is a mid-episode step.
            time_step = ts.transition(time_step.observation, reward)
            traj = trajectory.from_transition(time_step, action_step,
                                              next_time_step)

        self._observe(metric_observers, traj)

        return next_time_step
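The life-loss handling above relies on three TimeStep constructors to re-label steps. A minimal sketch of the step types they produce, assuming the py-side tf_agents.trajectories.time_step module used in this snippet (the observation below is a hypothetical placeholder):

import numpy as np
from tf_agents.trajectories import time_step as ts

obs = np.zeros((84, 84, 4), dtype=np.uint8)  # hypothetical observation

first = ts.restart(obs)               # StepType.FIRST, reward 0.0, discount 1.0
mid = ts.transition(obs, reward=0.0)  # StepType.MID, discount defaults to 1.0
last = ts.termination(obs, reward=1.0)  # StepType.LAST, discount 0.0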
Example #3
        def loop_body(counter, time_step, policy_state):
            """Runs a step in environment. While loop will call multiple times.

      Args:
        counter: Episode counters per batch index. Shape [batch_size].
        time_step: TimeStep tuple with elements shape [batch_size, ...].
        policy_state: Poicy state tensor shape [batch_size, policy_state_dim].
          Pass empty tuple for non-recurrent policies.
      Returns:
        loop_vars for next iteration of tf.while_loop.
      """
            action_step = self._policy.action(time_step, policy_state)
            policy_state = action_step.state
            next_time_step = self._env.step(action_step.action)

            traj = trajectory.from_transition(time_step, action_step,
                                              next_time_step)
            observer_ops = [observer(traj) for observer in self._observers]
            with tf.control_dependencies([tf.group(observer_ops)]):
                time_step, next_time_step, policy_state = nest.map_structure(
                    tf.identity, (time_step, next_time_step, policy_state))

            # The while-loop counter is only incremented on boundary (episode-reset) steps.
            counter += tf.cast(traj.is_boundary(), dtype=tf.int32)

            return [counter, next_time_step, policy_state]
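The docstring notes that the return value feeds tf.while_loop; a hedged sketch of how a driver might wire loop_body into that loop (batch_size, num_episodes and the reset/initial-state calls are assumptions, not taken from the snippet):

counter = tf.zeros([batch_size], dtype=tf.int32)
time_step = self._env.reset()
policy_state = self._policy.get_initial_state(batch_size)

# Keep stepping until every batch element has completed num_episodes episodes.
counter, time_step, policy_state = tf.while_loop(
    cond=lambda counter, *_: tf.reduce_any(counter < num_episodes),
    body=loop_body,
    loop_vars=[counter, time_step, policy_state])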
Example #4
    def run(self, time_step, policy_state=()):
        """Run policy in environment given initial time_step and policy_state.

        Args:
          time_step: The initial time_step.
          policy_state: The initial policy_state.

        Returns:
          A tuple (final time_step, final policy_state).
        """
        num_steps = 0
        num_episodes = 0
        while num_steps < self._max_steps and num_episodes < self._max_episodes:
            action_step = self.policy.action(time_step, policy_state)
            next_time_step = self.env.step(action_step.action)

            traj = trajectory.from_transition(time_step, action_step,
                                              next_time_step)
            for observer in self.observers:
                observer(traj)

            num_episodes += np.sum(traj.is_last())
            num_steps += np.sum(~traj.is_boundary())

            time_step = next_time_step
            policy_state = action_step.state

        return time_step, policy_state
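A hedged usage sketch, assuming this run method belongs to a PyDriver-style class as in tf_agents.drivers.py_driver (the environment, policy, and step/episode limits below are illustrative):

from tf_agents.drivers import py_driver
from tf_agents.environments import suite_gym
from tf_agents.policies import random_py_policy
from tf_agents.trajectories import time_step as ts

env = suite_gym.load('CartPole-v0')
policy = random_py_policy.RandomPyPolicy(
    ts.time_step_spec(env.observation_spec()), env.action_spec())
collected = []  # observers just append trajectories here

driver = py_driver.PyDriver(
    env, policy, observers=[collected.append], max_steps=100, max_episodes=10)
final_time_step, final_policy_state = driver.run(env.reset())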
Example #5
  def _setup_specs(self):
    self._policy_step_spec = policy_step.PolicyStep(
        action=self._action_spec,
        state=self._policy_state_spec,
        info=self._info_spec)
    self._trajectory_spec = trajectory.from_transition(
        self._time_step_spec, self._policy_step_spec, self._time_step_spec)
Example #6
    def _fill_replay_buffer(self):
        # Generate N frames: the value of pixels is the frame index.
        # The observations will be generated by stacking K frames out of those N,
        # generating some redundancies between the observations.
        single_frames = []
        frame_count = 100
        for k in range(frame_count):
            single_frames.append(np.full(self._single_shape, k,
                                         dtype=np.int32))

        # Add stack of frames to the replay buffer.
        time_steps = []
        for k in range(len(single_frames) - self._stack_count + 1):
            observation = np.concatenate(single_frames[k:k +
                                                       self._stack_count],
                                         axis=-1)
            time_steps.append(ts.transition(observation, reward=0.0))

        self._transition_count = len(time_steps) - 1
        dummy_action = policy_step.PolicyStep(np.int32(0))
        for k in range(self._transition_count):
            self._replay_buffer.add_batch(
                nest_utils.batch_nested_array(
                    trajectory.from_transition(time_steps[k], dummy_action,
                                               time_steps[k + 1])))
Example #7
def collect_step(environment, policy):
    time_step = environment.current_time_step()
    action_step = policy.action(time_step)
    next_time_step = environment.step(action_step.action)
    traj = trajectory.from_transition(time_step, action_step, next_time_step)

    return traj
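A hedged usage note: this variant returns the trajectory instead of storing it, so a caller typically loops and forwards it to a buffer (initial_collect_steps, train_env, collect_policy, and replay_buffer below are assumptions):

for _ in range(initial_collect_steps):
    traj = collect_step(train_env, collect_policy)
    replay_buffer.add_batch(traj)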
Example #8
def collect_step(environment, policy):
  time_step = environment.current_time_step()
  action_step = policy.action(time_step)
  next_time_step = environment.step(action_step.action)
  traj = trajectory.from_transition(time_step, action_step, next_time_step)

  # Add trajectory to the replay buffer
  replay_buffer.add_batch(traj)
Example #9
def collect_step(env, time_step, py_policy, replay_buffer):
  """Steps the environment and collects experience into the replay buffer."""
  action_step = py_policy.action(time_step)
  next_time_step = env.step(action_step.action)
  if not time_step.is_last():
    traj = trajectory.from_transition(time_step, action_step, next_time_step)
    replay_buffer.add_batch(traj)
  return next_time_step
Example #10
def collect_step(environment, policy):
    time_step = environment.current_time_step()
    action_step = policy.action(time_step)
    next_time_step = environment.step(action_step.action)
    traj = trajectory.from_transition(time_step, action_step, next_time_step)

    # Add trajectory to the replay buffer
    replay_buffer.add_batch(traj)
Example #11
def make_replay_buffer(tf_env):
    """Default replay buffer factory."""

    time_step_spec = tf_env.time_step_spec()
    action_step_spec = policy_step.PolicyStep(
        tf_env.action_spec(), (), tensor_spec.TensorSpec((), tf.int32))
    trajectory_spec = trajectory.from_transition(time_step_spec,
                                                 action_step_spec,
                                                 time_step_spec)
    return tf_uniform_replay_buffer.TFUniformReplayBuffer(trajectory_spec,
                                                          batch_size=1)
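A hedged usage sketch for this factory, reusing the driver and dataset wiring that appears in the later examples on this page (tf_env, collect_policy, and the batch sizes are illustrative). Note that the trajectory spec above declares a scalar int32 info field, so the collecting policy's info_spec must match:

replay_buffer = make_replay_buffer(tf_env)

collect_driver = dynamic_step_driver.DynamicStepDriver(
    tf_env,
    collect_policy,
    observers=[replay_buffer.add_batch],
    num_steps=1)

dataset = replay_buffer.as_dataset(sample_batch_size=64, num_steps=2)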
Example #12
  def _make_ppo_trajectory_spec(self, action_distribution_params_spec):
    # Make policy_step_spec with action_spec, empty tuple for policy_state, and
    # (act_log_prob_spec, value_pred_spec, action_distribution_params_spec) for
    # info.
    policy_step_spec = policy_step.PolicyStep(
        action=self.action_spec(),
        state=self._policy.policy_state_spec(),
        info=action_distribution_params_spec)
    trajectory_spec = trajectory.from_transition(
        self.time_step_spec(), policy_step_spec, self.time_step_spec())
    return trajectory_spec
Example #13
  def _initial_collect(self):
    """Collect initial experience before training begins."""
    logging.info('Collecting initial experience...')
    time_step_spec = ts.time_step_spec(self._env.observation_spec())
    random_policy = random_py_policy.RandomPyPolicy(
        time_step_spec, self._env.action_spec())
    time_step = self._env.reset()
    while self._replay_buffer.size < self._initial_collect_steps:
      if self.game_over():
        time_step = self._env.reset()
      action_step = random_policy.action(time_step)
      next_time_step = self._env.step(action_step.action)
      self._replay_buffer.add_batch(trajectory.from_transition(
          time_step, action_step, next_time_step))
      time_step = next_time_step
    logging.info('Done.')
Example #14
    def _create_replay_buffer(self, rb_cls):
        self._stack_count = 4
        self._single_shape = (15, 15, 1)
        shape = (15, 15, self._stack_count)
        observation_spec = array_spec.ArraySpec(shape, np.int32, 'obs')
        time_step_spec = ts.time_step_spec(observation_spec)
        action_spec = policy_step.PolicyStep(
            array_spec.BoundedArraySpec(shape=(),
                                        dtype=np.int32,
                                        minimum=0,
                                        maximum=1,
                                        name='action'))
        self._trajectory_spec = trajectory.from_transition(
            time_step_spec, action_spec, time_step_spec)

        self._capacity = 32
        self._replay_buffer = rb_cls(data_spec=self._trajectory_spec,
                                     capacity=self._capacity)
Example #15
def train_eval(
    root_dir,
    env_name='CartPole-v0',
    num_iterations=100000,
    fc_layer_params=(100,),
    # Params for collect
    initial_collect_steps=1000,
    collect_steps_per_iteration=1,
    epsilon_greedy=0.1,
    replay_buffer_capacity=100000,
    # Params for target update
    target_update_tau=0.05,
    target_update_period=5,
    # Params for train
    train_steps_per_iteration=1,
    batch_size=64,
    learning_rate=1e-3,
    gamma=0.99,
    reward_scale_factor=1.0,
    gradient_clipping=None,
    # Params for eval
    num_eval_episodes=10,
    eval_interval=1000,
    # Params for summaries and logging
    log_interval=1000,
    summary_interval=1000,
    summaries_flush_secs=10,
    debug_summaries=False,
    summarize_grads_and_vars=False,
    eval_metrics_callback=None):
  """A simple train and eval for DQN."""
  root_dir = os.path.expanduser(root_dir)
  train_dir = os.path.join(root_dir, 'train')
  eval_dir = os.path.join(root_dir, 'eval')

  train_summary_writer = tf.contrib.summary.create_file_writer(
      train_dir, flush_millis=summaries_flush_secs * 1000)
  train_summary_writer.set_as_default()

  eval_summary_writer = tf.contrib.summary.create_file_writer(
      eval_dir, flush_millis=summaries_flush_secs * 1000)
  eval_metrics = [
      tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
      tf_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes)
  ]

  # TODO(kbanoop): Figure out if it is possible to avoid the with block.
  with tf.contrib.summary.record_summaries_every_n_global_steps(
      summary_interval):

    tf_env = tf_py_environment.TFPyEnvironment(suite_gym.load(env_name))
    eval_tf_env = tf_py_environment.TFPyEnvironment(suite_gym.load(env_name))

    trajectory_spec = trajectory.from_transition(
        time_step=tf_env.time_step_spec(),
        action_step=policy_step.PolicyStep(action=tf_env.action_spec()),
        next_time_step=tf_env.time_step_spec())
    replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
        data_spec=trajectory_spec,
        batch_size=tf_env.batch_size,
        max_length=replay_buffer_capacity)

    q_net = q_network.QNetwork(
        tf_env.time_step_spec().observation,
        tf_env.action_spec(),
        fc_layer_params=fc_layer_params)

    tf_agent = dqn_agent.DqnAgent(
        tf_env.time_step_spec(),
        tf_env.action_spec(),
        q_network=q_net,
        # TODO(kbanoop): Decay epsilon based on global step, cf. cl/188907839
        epsilon_greedy=epsilon_greedy,
        target_update_tau=target_update_tau,
        target_update_period=target_update_period,
        optimizer=tf.train.AdamOptimizer(learning_rate=learning_rate),
        td_errors_loss_fn=dqn_agent.element_wise_squared_loss,
        gamma=gamma,
        reward_scale_factor=reward_scale_factor,
        gradient_clipping=gradient_clipping,
        debug_summaries=debug_summaries,
        summarize_grads_and_vars=summarize_grads_and_vars)

    tf_agent.initialize()
    train_metrics = [
        tf_metrics.NumberOfEpisodes(),
        tf_metrics.EnvironmentSteps(),
        tf_metrics.AverageReturnMetric(),
        tf_metrics.AverageEpisodeLengthMetric(),
    ]

    eval_policy = tf_agent.policy()
    collect_policy = tf_agent.collect_policy()

    collect_driver = dynamic_step_driver.DynamicStepDriver(
        tf_env,
        collect_policy,
        observers=[replay_buffer.add_batch] + train_metrics,
        num_steps=collect_steps_per_iteration)

    global_step = tf.train.get_or_create_global_step()

    initial_collect_policy = random_tf_policy.RandomTFPolicy(
        tf_env.time_step_spec(), tf_env.action_spec())

    # Collect initial replay data.
    tf.logging.info(
        'Initializing replay buffer by collecting experience for %d steps with '
        'a random policy.' % initial_collect_steps)
    dynamic_step_driver.DynamicStepDriver(
        tf_env,
        initial_collect_policy,
        observers=[replay_buffer.add_batch],
        num_steps=initial_collect_steps).run()

    metrics = metric_utils.eager_compute(
        eval_metrics,
        eval_tf_env,
        eval_policy,
        num_episodes=num_eval_episodes,
        summary_writer=eval_summary_writer,
        summary_prefix='Metrics',
    )
    if eval_metrics_callback is not None:
      eval_metrics_callback(metrics, global_step.numpy())

    time_step = None
    policy_state = ()

    timed_at_step = global_step.numpy()
    time_acc = 0

    # Dataset generates trajectories with shape [Bx2x...]
    dataset = replay_buffer.as_dataset(
        num_parallel_calls=3,
        sample_batch_size=batch_size,
        num_steps=2).prefetch(3)
    iterator = iter(dataset)

    for _ in range(num_iterations):
      start_time = time.time()
      time_step, policy_state = collect_driver.run(
          time_step=time_step,
          policy_state=policy_state,
      )
      for _ in range(train_steps_per_iteration):
        experience, _ = next(iterator)
        train_loss = tf_agent.train(experience, train_step_counter=global_step)
      time_acc += time.time() - start_time

      if global_step.numpy() % log_interval == 0:
        tf.logging.info('step = %d, loss = %f', global_step.numpy(), train_loss)
        steps_per_sec = (global_step.numpy() - timed_at_step) / time_acc
        tf.logging.info('%.3f steps/sec' % steps_per_sec)
        tf.contrib.summary.scalar(name='global_steps/sec', tensor=steps_per_sec)
        timed_at_step = global_step.numpy()
        time_acc = 0

      for train_metric in train_metrics:
        train_metric.tf_summaries(step_metrics=train_metrics[:2])

      if global_step.numpy() % eval_interval == 0:
        metrics = metric_utils.eager_compute(
            eval_metrics,
            eval_tf_env,
            eval_policy,
            num_episodes=num_eval_episodes,
            summary_writer=eval_summary_writer,
            summary_prefix='Metrics',
        )
        if eval_metrics_callback is not None:
          eval_metrics_callback(metrics, global_step.numpy())
    return train_loss
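A hedged entry-point sketch for calling train_eval above; the root directory is a placeholder, and the .numpy() calls inside the function assume eager execution is enabled (this is tf.contrib-era TF 1.x code):

import tensorflow as tf

def main(_):
  tf.enable_eager_execution()
  train_eval('/tmp/dqn_cartpole', num_iterations=20000)

if __name__ == '__main__':
  tf.app.run()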
Example #16
  def metric_observer(time_step, action, next_time_step, policy_state):
    action_step = policy_step.PolicyStep(action, policy_state, ())
    traj = trajectory.from_transition(time_step, action_step, next_time_step)
    return metric(traj)
Example #17
    def __init__(
            self,
            root_dir,
            env_name,
            num_iterations=200,
            max_episode_frames=108000,  # ALE frames
            terminal_on_life_loss=False,
            conv_layer_params=((32, (8, 8), 4), (64, (4, 4), 2), (64, (3, 3),
                                                                  1)),
            fc_layer_params=(512, ),
            # Params for collect
            initial_collect_steps=80000,  # ALE frames
            epsilon_greedy=0.01,
            epsilon_decay_period=1000000,  # ALE frames
            replay_buffer_capacity=1000000,
            # Params for train
            train_steps_per_iteration=1000000,  # ALE frames
            update_period=16,  # ALE frames
            target_update_tau=1.0,
            target_update_period=32000,  # ALE frames
            batch_size=32,
            learning_rate=2.5e-4,
            gamma=0.99,
            reward_scale_factor=1.0,
            gradient_clipping=None,
            # Params for eval
            do_eval=True,
            eval_steps_per_iteration=500000,  # ALE frames
            eval_epsilon_greedy=0.001,
            # Params for checkpoints, summaries, and logging
            log_interval=1000,
            summary_interval=1000,
            summaries_flush_secs=10,
            debug_summaries=False,
            summarize_grads_and_vars=False,
            eval_metrics_callback=None):
        """A simple Atari train and eval for DQN.

        Args:
          root_dir: Directory to write log files to.
          env_name: Fully-qualified name of the Atari environment (e.g. Pong-v0).
          num_iterations: Number of train/eval iterations to run.
          max_episode_frames: Maximum length of a single episode, in ALE frames.
          terminal_on_life_loss: Whether to simulate an episode termination when a
            life is lost.
          conv_layer_params: Params for convolutional layers of QNetwork.
          fc_layer_params: Params for fully connected layers of QNetwork.
          initial_collect_steps: Number of ALE frames to process before beginning
            to train. Since this is in ALE frames, there will be
            initial_collect_steps/4 items in the RB when training starts.
          epsilon_greedy: Final epsilon value to decay to for training.
          epsilon_decay_period: Period over which to decay epsilon, from 1.0 to
            epsilon_greedy (defined above).
          replay_buffer_capacity: Maximum number of items to store in the RB.
          train_steps_per_iteration: Number of ALE frames to run through for each
            iteration of training.
          update_period: Run a train operation every update_period ALE frames.
          target_update_tau: Coefficient for soft target network updates (1.0 ==
            hard updates).
          target_update_period: Period, in ALE frames, to copy the live network to
            the target network.
          batch_size: Number of frames to include in each training batch.
          learning_rate: RMS optimizer learning rate.
          gamma: Discount for future rewards.
          reward_scale_factor: Scaling factor for rewards.
          gradient_clipping: Norm length to clip gradients.
          do_eval: If True, run an eval every iteration. If False, skip eval.
          eval_steps_per_iteration: Number of ALE frames to run through for each
            iteration of evaluation.
          eval_epsilon_greedy: Epsilon value to use for the evaluation policy (0 ==
            totally greedy policy).
          log_interval: Log stats to the terminal every log_interval training
            steps.
          summary_interval: Write TF summaries every summary_interval training
            steps.
          summaries_flush_secs: Flush summaries to disk every summaries_flush_secs
            seconds.
          debug_summaries: If True, write additional summaries for debugging (see
            dqn_agent for which summaries are written).
          summarize_grads_and_vars: Include gradients in summaries.
          eval_metrics_callback: A callback function that takes (metric_dict,
            global_step) as parameters. Called after every eval with the results of
            the evaluation.
        """
        self._update_period = update_period / ATARI_FRAME_SKIP
        self._train_steps_per_iteration = (train_steps_per_iteration /
                                           ATARI_FRAME_SKIP)
        self._do_eval = do_eval
        self._eval_steps_per_iteration = eval_steps_per_iteration / ATARI_FRAME_SKIP
        self._eval_epsilon_greedy = eval_epsilon_greedy
        self._initial_collect_steps = initial_collect_steps / ATARI_FRAME_SKIP
        self._summary_interval = summary_interval
        self._num_iterations = num_iterations
        self._log_interval = log_interval
        self._eval_metrics_callback = eval_metrics_callback

        with gin.unlock_config():
            gin.bind_parameter('AtariPreprocessing.terminal_on_life_loss',
                               terminal_on_life_loss)

        root_dir = os.path.expanduser(root_dir)
        train_dir = os.path.join(root_dir, 'train')
        eval_dir = os.path.join(root_dir, 'eval')

        train_summary_writer = tf.compat.v2.summary.create_file_writer(
            train_dir, flush_millis=summaries_flush_secs * 1000)
        train_summary_writer.set_as_default()
        self._train_summary_writer = train_summary_writer

        self._eval_summary_writer = None
        if self._do_eval:
            eval_summary_writer = tf.compat.v2.summary.create_file_writer(
                eval_dir, flush_millis=summaries_flush_secs * 1000)
            self._eval_summary_writer = eval_summary_writer
            self._eval_metrics = [
                py_metrics.AverageReturnMetric(name='PhaseAverageReturn',
                                               buffer_size=np.inf),
                py_metrics.AverageEpisodeLengthMetric(
                    name='PhaseAverageEpisodeLength', buffer_size=np.inf),
            ]

        self._global_step = tf.compat.v1.train.get_or_create_global_step()
        with tf.compat.v2.summary.record_if(lambda: tf.math.equal(
                self._global_step % self._summary_interval, 0)):
            self._env = suite_atari.load(
                env_name,
                max_episode_steps=max_episode_frames / ATARI_FRAME_SKIP,
                gym_env_wrappers=suite_atari.
                DEFAULT_ATARI_GYM_WRAPPERS_WITH_STACKING)
            self._env = batched_py_environment.BatchedPyEnvironment(
                [self._env])

            observation_spec = tensor_spec.from_spec(
                self._env.observation_spec())
            time_step_spec = ts.time_step_spec(observation_spec)
            action_spec = tensor_spec.from_spec(self._env.action_spec())

            with tf.device('/cpu:0'):
                epsilon = tf.compat.v1.train.polynomial_decay(
                    1.0,
                    self._global_step,
                    epsilon_decay_period / ATARI_FRAME_SKIP /
                    self._update_period,
                    end_learning_rate=epsilon_greedy)

            with tf.device('/gpu:0'):
                optimizer = tf.compat.v1.train.RMSPropOptimizer(
                    learning_rate=learning_rate,
                    decay=0.95,
                    momentum=0.0,
                    epsilon=0.00001,
                    centered=True)
                q_net = AtariQNetwork(observation_spec,
                                      action_spec,
                                      conv_layer_params=conv_layer_params,
                                      fc_layer_params=fc_layer_params)
                tf_agent = dqn_agent.DqnAgent(
                    time_step_spec,
                    action_spec,
                    q_network=q_net,
                    optimizer=optimizer,
                    epsilon_greedy=epsilon,
                    target_update_tau=target_update_tau,
                    target_update_period=(target_update_period /
                                          ATARI_FRAME_SKIP /
                                          self._update_period),
                    td_errors_loss_fn=dqn_agent.element_wise_huber_loss,
                    gamma=gamma,
                    reward_scale_factor=reward_scale_factor,
                    gradient_clipping=gradient_clipping,
                    debug_summaries=debug_summaries,
                    summarize_grads_and_vars=summarize_grads_and_vars,
                    train_step_counter=self._global_step)

                self._collect_policy = py_tf_policy.PyTFPolicy(
                    tf_agent.collect_policy)

                if self._do_eval:
                    self._eval_policy = py_tf_policy.PyTFPolicy(
                        epsilon_greedy_policy.EpsilonGreedyPolicy(
                            policy=tf_agent.policy,
                            epsilon=self._eval_epsilon_greedy))

                py_observation_spec = self._env.observation_spec()
                py_time_step_spec = ts.time_step_spec(py_observation_spec)
                py_action_spec = policy_step.PolicyStep(
                    self._env.action_spec())
                data_spec = trajectory.from_transition(py_time_step_spec,
                                                       py_action_spec,
                                                       py_time_step_spec)
                self._replay_buffer = (
                    py_hashed_replay_buffer.PyHashedReplayBuffer(
                        data_spec=data_spec, capacity=replay_buffer_capacity))

            with tf.device('/cpu:0'):
                ds = self._replay_buffer.as_dataset(
                    sample_batch_size=batch_size, num_steps=2).prefetch(4)
                # TODO(b/123242430): Add prefetch_to_device back here in order to
                # improve performance once errors are resolved.
                self._ds_itr = tf.compat.v1.data.make_one_shot_iterator(ds)
                experience = self._ds_itr.get_next()

            with tf.device('/gpu:0'):
                self._train_op = tf_agent.train(experience)

                self._env_steps_metric = py_metrics.EnvironmentSteps()
                self._step_metrics = [
                    py_metrics.NumberOfEpisodes(),
                    self._env_steps_metric,
                ]
                self._train_metrics = self._step_metrics + [
                    py_metrics.AverageReturnMetric(buffer_size=10),
                    py_metrics.AverageEpisodeLengthMetric(buffer_size=10),
                ]
                # The _train_phase_metrics average over an entire train iteration,
                # rather than the rolling average of the last 10 episodes.
                self._train_phase_metrics = [
                    py_metrics.AverageReturnMetric(name='PhaseAverageReturn',
                                                   buffer_size=np.inf),
                    py_metrics.AverageEpisodeLengthMetric(
                        name='PhaseAverageEpisodeLength', buffer_size=np.inf),
                ]
                self._iteration_metric = py_metrics.CounterMetric(
                    name='Iteration')

                # Summaries written from python should run every time they are
                # generated.
                with tf.compat.v2.summary.record_if(True):
                    self._steps_per_second_ph = tf.compat.v1.placeholder(
                        tf.float32, shape=(), name='steps_per_sec_ph')
                    self._steps_per_second_summary = tf.contrib.summary.scalar(
                        name='global_steps/sec',
                        tensor=self._steps_per_second_ph)

                    for metric in self._train_metrics:
                        metric.tf_summaries(step_metrics=self._step_metrics)

                    for metric in self._train_phase_metrics:
                        metric.tf_summaries(
                            step_metrics=(self._iteration_metric, ))
                    self._iteration_metric.tf_summaries()

                    if self._do_eval:
                        with eval_summary_writer.as_default():
                            for metric in self._eval_metrics:
                                metric.tf_summaries(
                                    step_metrics=(self._iteration_metric, ))

                self._train_checkpointer = common_utils.Checkpointer(
                    ckpt_dir=train_dir,
                    agent=tf_agent,
                    global_step=self._global_step,
                    optimizer=optimizer,
                    metrics=metric_utils.MetricsGroup(
                        self._train_metrics + self._train_phase_metrics +
                        [self._iteration_metric], 'train_metrics'))
                self._policy_checkpointer = common_utils.Checkpointer(
                    ckpt_dir=os.path.join(train_dir, 'policy'),
                    policy=tf_agent.policy,
                    global_step=self._global_step)
                self._rb_checkpointer = common_utils.Checkpointer(
                    ckpt_dir=os.path.join(train_dir, 'replay_buffer'),
                    max_to_keep=1,
                    replay_buffer=self._replay_buffer)

                self._init_agent_op = tf_agent.initialize()
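The constructor above divides every frame-denominated knob by the frame skip; a small worked example of that conversion, assuming ATARI_FRAME_SKIP = 4 (the value implied by the "initial_collect_steps/4" remark in the docstring):

ATARI_FRAME_SKIP = 4  # assumed, per the docstring above

update_period = 16                    # ALE frames
train_steps_per_iteration = 1000000   # ALE frames
initial_collect_steps = 80000         # ALE frames

agent_update_period = update_period // ATARI_FRAME_SKIP            # 4 agent steps
agent_train_steps = train_steps_per_iteration // ATARI_FRAME_SKIP  # 250,000 agent steps
initial_rb_items = initial_collect_steps // ATARI_FRAME_SKIP       # 20,000 items in the RB

# target_update_period is additionally divided by the update period, so the
# default 32000 ALE frames -> 32000 / 4 / 4 = 2000 train steps between target copies.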