Example No. 1
def main(_):
  while True:
    try:
      # Client to communicate with the learner.
      client = grpc.Client(FLAGS.server_address)

      env = config.create_environment(FLAGS.task)

      # Unique ID to identify a specific run of an actor.
      run_id = np.random.randint(np.iinfo(np.int64).max)
      observation = env.reset()
      reward = 0.0
      raw_reward = 0.0
      done = False

      while True:
        env_output = utils.EnvOutput(reward, done, np.array(observation))
        action = client.inference((FLAGS.task, run_id, env_output, raw_reward))
        observation, reward, done, info = env.step(action.numpy())
        raw_reward = float(info.get('score_reward', reward))

        if done:
          observation = env.reset()
    except (tf.errors.UnavailableError, tf.errors.CancelledError) as e:
      logging.exception(e)
      env.close()
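
These examples all build on utils.EnvOutput. A minimal stand-in consistent with how the snippets construct it, assuming it is a plain named tuple (the real utils module may define additional fields or helpers):

import collections

# Minimal stand-in for utils.EnvOutput, matching the three positional fields
# the examples pass: reward, done, observation.
EnvOutput = collections.namedtuple('EnvOutput', 'reward done observation')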
Example No. 2
  def _create_env_output(self, batch_size, unroll_length):
    return utils.EnvOutput(
        reward=tf.random.uniform([unroll_length, batch_size]),
        done=tf.cast(tf.random.uniform([unroll_length, batch_size],
                                       maxval=2, dtype=tf.int32),
                     tf.bool),
        observation=self._random_obs(batch_size, unroll_length))
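
The test helper above delegates to self._random_obs, which is not shown here. A hypothetical implementation that would produce observations with the same time-major [unroll_length, batch_size, ...] prefix as reward and done (the observation shape is an illustrative assumption):

  def _random_obs(self, batch_size, unroll_length, obs_shape=(84, 84, 1)):
    # Hypothetical helper: random float observations sharing the
    # [unroll_length, batch_size] leading dimensions of reward and done.
    return tf.random.uniform([unroll_length, batch_size, *obs_shape])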
Example No. 3
def actor_loop(create_env_fn):
  """Main actor loop.

  Args:
    create_env_fn: Callable (taking the task ID as argument) that must return a
      newly created environment.
  """
  logging.info('Starting actor loop')
  if are_summaries_enabled():
    summary_writer = tf.summary.create_file_writer(
        os.path.join(FLAGS.logdir, 'actor_{}'.format(FLAGS.task)),
        flush_millis=20000, max_queue=1000)
    timer_cls = profiling.ExportingTimer
  else:
    summary_writer = tf.summary.create_noop_writer()
    timer_cls = utils.nullcontext

  actor_step = 0
  with summary_writer.as_default():
    while True:
      try:
        # Client to communicate with the learner.
        client = grpc.Client(FLAGS.server_address)

        env = create_env_fn(FLAGS.task)

        # Unique ID to identify a specific run of an actor.
        run_id = np.random.randint(np.iinfo(np.int64).max)
        observation = env.reset()
        reward = 0.0
        raw_reward = 0.0
        done = False

        while True:
          tf.summary.experimental.set_step(actor_step)
          env_output = utils.EnvOutput(reward, done, observation)
          with timer_cls('actor/elapsed_inference_s', 1000):
            action = client.inference(
                (FLAGS.task, run_id, env_output, raw_reward))
          with timer_cls('actor/elapsed_env_step_s', 1000):
            observation, reward, done, info = env.step(action.numpy())
          raw_reward = float(info.get('score_reward', reward))
          if done:
            with timer_cls('actor/elapsed_env_reset_s', 10):
              observation = env.reset()
          actor_step += 1
      except (tf.errors.UnavailableError, tf.errors.CancelledError) as e:
        logging.exception(e)
        env.close()
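
actor_loop only requires that create_env_fn take a task ID and return a newly created environment. A hypothetical factory satisfying that contract, assuming a Gym-based setup (the game names below are illustrative and not part of the original code):

import gym

# Hypothetical task-to-game table; the real mapping comes from the experiment
# configuration.
_TASK_GAMES = ('PongNoFrameskip-v4', 'BreakoutNoFrameskip-v4')


def create_env_fn(task):
  # Return a fresh environment for the given task ID, as required by the
  # actor_loop docstring.
  return gym.make(_TASK_GAMES[task % len(_TASK_GAMES)])

The actor would then be started as actor_loop(create_env_fn).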
Example No. 4
def main(_):
  validate_config()
  settings = utils.init_learner(FLAGS.num_training_tpus)
  strategy, inference_devices, training_strategy, encode, decode = settings
  # Environment specification.
  env = config.create_environment(0)
  env_output_specs = utils.EnvOutput(
      tf.TensorSpec([], tf.float32, 'reward'),
      tf.TensorSpec([], tf.bool, 'done'),
      tf.TensorSpec(env.observation_space.shape, env.observation_space.dtype,
                    'observation'),
  )
  action_specs = tf.TensorSpec([], tf.int32, 'action')
  num_actions = env.action_space.n
  agent_input_specs = (action_specs, env_output_specs)

  # Initialize agent and variables.
  agent = config.create_agent(env_output_specs, num_actions)
  initial_agent_state = agent.initial_state(1)
  agent_state_specs = tf.nest.map_structure(
      lambda t: tf.TensorSpec(t.shape[1:], t.dtype), initial_agent_state)
  input_ = tf.nest.map_structure(
      lambda s: tf.zeros([1] + list(s.shape), s.dtype), agent_input_specs)
  input_ = encode(input_)

  with strategy.scope():
    @tf.function
    def create_variables(*args):
      return agent(*decode(args))

    initial_agent_output, _ = create_variables(input_, initial_agent_state)
    # Create optimizer.
    iter_frame_ratio = (
        FLAGS.batch_size * FLAGS.unroll_length * FLAGS.num_action_repeats)
    final_iteration = int(
        math.ceil(FLAGS.total_environment_frames / iter_frame_ratio))
    optimizer, learning_rate_fn = config.create_optimizer(final_iteration)

    iterations = optimizer.iterations
    # Eagerly create the optimizer's hyperparameter and slot variables inside
    # the strategy scope (uses private Keras optimizer methods).
    optimizer._create_hypers()
    optimizer._create_slots(agent.trainable_variables)

    # ON_READ causes the replicated variable to act as an independent variable
    # for each replica.
    temp_grads = [
        tf.Variable(tf.zeros_like(v), trainable=False,
                    synchronization=tf.VariableSynchronization.ON_READ)
        for v in agent.trainable_variables
    ]

  @tf.function
  def minimize(iterator):
    data = next(iterator)

    def compute_gradients(args):
      args = tf.nest.pack_sequence_as(unroll_specs, decode(args, data))
      with tf.GradientTape() as tape:
        loss = compute_loss(agent, *args)
      grads = tape.gradient(loss, agent.trainable_variables)
      for t, g in zip(temp_grads, grads):
        t.assign(g)
      return loss

    loss = training_strategy.experimental_run_v2(compute_gradients, (data,))
    loss = training_strategy.experimental_local_results(loss)[0]

    def apply_gradients(_):
      optimizer.apply_gradients(zip(temp_grads, agent.trainable_variables))

    strategy.experimental_run_v2(apply_gradients, (loss,))

  agent_output_specs = tf.nest.map_structure(
      lambda t: tf.TensorSpec(t.shape[1:], t.dtype), initial_agent_output)
  # Logging.
  summary_writer = tf.summary.create_file_writer(
      FLAGS.logdir, flush_millis=20000, max_queue=1000)

  # Set up checkpointing and restore checkpoint.
  ckpt = tf.train.Checkpoint(agent=agent, optimizer=optimizer)
  manager = tf.train.CheckpointManager(
      ckpt, FLAGS.logdir, max_to_keep=1, keep_checkpoint_every_n_hours=6)
  last_ckpt_time = 0  # Force checkpointing of the initial model.
  if manager.latest_checkpoint:
    logging.info('Restoring checkpoint: %s', manager.latest_checkpoint)
    ckpt.restore(manager.latest_checkpoint).assert_consumed()
    last_ckpt_time = time.time()

  server = grpc.Server([FLAGS.server_address])

  store = utils.UnrollStore(
      FLAGS.num_actors, FLAGS.unroll_length,
      (action_specs, env_output_specs, agent_output_specs))
  actor_run_ids = utils.Aggregator(FLAGS.num_actors,
                                   tf.TensorSpec([], tf.int64, 'run_ids'))
  info_specs = (
      tf.TensorSpec([], tf.int64, 'episode_num_frames'),
      tf.TensorSpec([], tf.float32, 'episode_returns'),
      tf.TensorSpec([], tf.float32, 'episode_raw_returns'),
  )
  actor_infos = utils.Aggregator(FLAGS.num_actors, info_specs)

  # First agent state in an unroll.
  first_agent_states = utils.Aggregator(FLAGS.num_actors, agent_state_specs)

  # Current agent state and action.
  agent_states = utils.Aggregator(FLAGS.num_actors, agent_state_specs)
  actions = utils.Aggregator(FLAGS.num_actors, action_specs)

  unroll_specs = Unroll(agent_state_specs, *store.unroll_specs)
  unroll_queue = utils.StructuredFIFOQueue(1, unroll_specs)
  info_queue = utils.StructuredFIFOQueue(-1, info_specs)

  def add_batch_size(ts):
    return tf.TensorSpec([FLAGS.inference_batch_size] + list(ts.shape),
                         ts.dtype, ts.name)

  inference_iteration = tf.Variable(-1)
  inference_specs = (
      tf.TensorSpec([], tf.int32, 'actor_id'),
      tf.TensorSpec([], tf.int64, 'run_id'),
      env_output_specs,
      tf.TensorSpec([], tf.float32, 'raw_reward'),
  )
  inference_specs = tf.nest.map_structure(add_batch_size, inference_specs)
  @tf.function(input_signature=inference_specs)
  def inference(actor_ids, run_ids, env_outputs, raw_rewards):
    # Reset the actors that had their first run or crashed.
    previous_run_ids = actor_run_ids.read(actor_ids)
    actor_run_ids.replace(actor_ids, run_ids)
    reset_indices = tf.where(tf.not_equal(previous_run_ids, run_ids))[:, 0]
    actors_needing_reset = tf.gather(actor_ids, reset_indices)
    if tf.not_equal(tf.shape(actors_needing_reset)[0], 0):
      tf.print('Actor ids needing reset:', actors_needing_reset)
    actor_infos.reset(actors_needing_reset)
    store.reset(actors_needing_reset)
    initial_agent_states = agent.initial_state(
        tf.shape(actors_needing_reset)[0])
    first_agent_states.replace(actors_needing_reset, initial_agent_states)
    agent_states.replace(actors_needing_reset, initial_agent_states)
    actions.reset(actors_needing_reset)

    # Update steps and return.
    actor_infos.add(actor_ids, (0, env_outputs.reward, raw_rewards))
    done_ids = tf.gather(actor_ids, tf.where(env_outputs.done)[:, 0])
    info_queue.enqueue_many(actor_infos.read(done_ids))
    actor_infos.reset(done_ids)
    actor_infos.add(actor_ids, (FLAGS.num_action_repeats, 0., 0.))

    # Inference.
    prev_actions = actions.read(actor_ids)
    input_ = encode((prev_actions, env_outputs))
    prev_agent_states = agent_states.read(actor_ids)
    def make_inference_fn(inference_device):
      def device_specific_inference_fn():
        with tf.device(inference_device):
          @tf.function
          def agent_inference(*args):
            return agent(*decode(args))

          return agent_inference(input_, prev_agent_states)

      return device_specific_inference_fn

    # Distribute the inference calls among the inference cores.
    branch_index = inference_iteration.assign_add(1) % len(inference_devices)
    agent_outputs, curr_agent_states = tf.switch_case(branch_index, {
        i: make_inference_fn(inference_device)
        for i, inference_device in enumerate(inference_devices)
    })

    # Append the latest outputs to the unroll and insert completed unrolls in
    # queue.
    completed_ids, unrolls = store.append(
        actor_ids, (prev_actions, env_outputs, agent_outputs))
    unrolls = Unroll(first_agent_states.read(completed_ids), *unrolls)
    unroll_queue.enqueue_many(unrolls)
    first_agent_states.replace(completed_ids,
                               agent_states.read(completed_ids))

    # Update current state.
    agent_states.replace(actor_ids, curr_agent_states)
    actions.replace(actor_ids, agent_outputs.action)

    # Return environment actions to actors.
    return agent_outputs.action

  with strategy.scope():
    server.bind(inference, batched=True)
  server.start()

  def dequeue(ctx):
    # Create batch (time major).
    actor_outputs = tf.nest.map_structure(lambda *args: tf.stack(args), *[
        unroll_queue.dequeue()
        for i in range(ctx.get_per_replica_batch_size(FLAGS.batch_size))
    ])
    actor_outputs = actor_outputs._replace(
        prev_actions=utils.make_time_major(actor_outputs.prev_actions),
        env_outputs=utils.make_time_major(actor_outputs.env_outputs),
        agent_outputs=utils.make_time_major(actor_outputs.agent_outputs))
    actor_outputs = actor_outputs._replace(
        env_outputs=encode(actor_outputs.env_outputs))
    # tf.data.Dataset treats list leaves as tensors, so we need to flatten and
    # repack.
    return tf.nest.flatten(actor_outputs)

  def dataset_fn(ctx):
    dataset = tf.data.Dataset.from_tensors(0).repeat(None)
    return dataset.map(lambda _: dequeue(ctx),
                       num_parallel_calls=ctx.num_replicas_in_sync)

  dataset = training_strategy.experimental_distribute_datasets_from_function(
      dataset_fn)
  it = iter(dataset)

  # Execute learning and track performance.
  with summary_writer.as_default(), \
    concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
    log_future = executor.submit(lambda: None)  # No-op future.
    last_num_env_frames = iterations * iter_frame_ratio
    last_log_time = time.time()
    while iterations < final_iteration:
      num_env_frames = iterations * iter_frame_ratio
      tf.summary.experimental.set_step(num_env_frames)

      # Save checkpoint.
      current_time = time.time()
      if current_time - last_ckpt_time >= FLAGS.save_checkpoint_secs:
        manager.save()
        last_ckpt_time = current_time

      def log(num_env_frames):
        """Logs actor summaries."""
        summary_writer.set_as_default()
        tf.summary.experimental.set_step(num_env_frames)
        episode_num_frames, episode_returns, episode_raw_returns = (
            info_queue.dequeue_many(info_queue.size()))
        for n, r, s in zip(episode_num_frames, episode_returns,
                           episode_raw_returns):
          logging.info('Return: %f Frames: %i', r, n)
          tf.summary.scalar('episode_return', r)
          tf.summary.scalar('episode_raw_return', s)
          tf.summary.scalar('num_episode_frames', n)
      log_future.result()  # Raise exception if any occurred in logging.
      log_future = executor.submit(log, num_env_frames)

      if current_time - last_log_time >= 120:
        df = tf.cast(num_env_frames - last_num_env_frames, tf.float32)
        dt = time.time() - last_log_time
        tf.summary.scalar('num_environment_frames/sec', df / dt)
        tf.summary.scalar('learning_rate', learning_rate_fn(iterations))

        last_num_env_frames, last_log_time = num_env_frames, time.time()

      minimize(it)

  manager.save()
  server.shutdown()
  unroll_queue.close()
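
The dequeue function above stacks batch_size dequeued unrolls along a leading batch dimension and then converts the nested structure to time major before it reaches the training step. A minimal sketch of what a make_time_major helper of this shape does, assuming it simply swaps the leading batch and time axes of every tensor in the nest (the actual utils.make_time_major may differ in details):

def make_time_major(nest):
  """Swaps the leading [batch, time, ...] axes to [time, batch, ...]."""
  return tf.nest.map_structure(
      lambda t: tf.transpose(t, [1, 0] + list(range(2, t.shape.rank))),
      nest)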
Example No. 5
def main(_):
  validate_config()
  settings = utils.init_learner(FLAGS.num_training_tpus)
  strategy, inference_devices, training_strategy, encode, decode = settings
  # Environment specification.
  env = config.create_environment(0)
  env_output_specs = utils.EnvOutput(
      tf.TensorSpec([], tf.float32, 'reward'),
      tf.TensorSpec([], tf.bool, 'done'),
      tf.TensorSpec(env.observation_space.shape, env.observation_space.dtype,
                    'observation'),
  )
  action_specs = tf.TensorSpec([], tf.int32, 'action')
  num_actions = env.action_space.n
  agent_input_specs = (action_specs, env_output_specs)

  # Initialize agent and variables.
  agent = config.create_agent(env_output_specs, num_actions)
  target_agent = config.create_agent(env_output_specs, num_actions)
  initial_agent_state = agent.initial_state(1)
  agent_state_specs = tf.nest.map_structure(
      lambda t: tf.TensorSpec(t.shape[1:], t.dtype), initial_agent_state)
  input_ = tf.nest.map_structure(
      lambda s: tf.zeros([1] + list(s.shape), s.dtype), agent_input_specs)
  input_ = encode(input_)

  with strategy.scope():

    @tf.function
    def create_variables(*args):
      return agent(*decode(args))

    @tf.function
    def create_target_agent_variables(*args):
      return target_agent(*decode(args))

    # The first call to the Keras models creates the variables for the agent
    # and the target agent.
    initial_agent_output, _ = create_variables(input_, initial_agent_state)
    create_target_agent_variables(input_, initial_agent_state)

    @tf.function
    def update_target_agent():
      """Synchronizes training and target agent variables."""
      variables = agent.trainable_variables
      target_variables = target_agent.trainable_variables
      assert len(target_variables) == len(variables), (
          'Mismatch in number of net tensors: {} != {}'.format(
              len(target_variables), len(variables)))
      for target_var, source_var in zip(target_variables, variables):
        target_var.assign(source_var)

    # Create optimizer.
    iter_frame_ratio = (
        get_replay_insertion_batch_size() *
        FLAGS.unroll_length * FLAGS.num_action_repeats)
    final_iteration = int(
        math.ceil(FLAGS.total_environment_frames / iter_frame_ratio))
    optimizer, learning_rate_fn = config.create_optimizer(final_iteration)

    iterations = optimizer.iterations
    # Eagerly create the optimizer's hyperparameter and slot variables inside
    # the strategy scope (uses private Keras optimizer methods).
    optimizer._create_hypers()
    optimizer._create_slots(agent.trainable_variables)

    # ON_READ causes the replicated variable to act as an independent variable
    # for each replica.
    temp_grads = [
        tf.Variable(tf.zeros_like(v), trainable=False,
                    synchronization=tf.VariableSynchronization.ON_READ)
        for v in agent.trainable_variables
    ]

  @tf.function
  def minimize(iterator):
    """Computes and applies gradients.

    Args:
      iterator: An iterator of distributed dataset that produces `PerReplica`.

    Returns:
      A tuple:
        - priorities, the new priorities. Shape <float32>[batch_size].
        - indices, the indices for updating priorities. Shape
        <int32>[batch_size].
        - gradient_norm_before_clip, a scalar.
    """
    data = next(iterator)

    def compute_gradients(args):
      """A function to pass to `Strategy` for gradient computation."""
      args = decode(args, data)
      args = tf.nest.pack_sequence_as(SampledUnrolls(unroll_specs, 0, 0), args)
      with tf.GradientTape() as tape:
        # loss: [batch_size]
        # priorities: [batch_size]
        loss, priorities = compute_loss_and_priorities(
            agent,
            target_agent,
            args.unrolls.agent_state,
            args.unrolls.prev_actions,
            args.unrolls.env_outputs,
            args.unrolls.agent_outputs,
            gamma=FLAGS.discounting,
            burn_in=FLAGS.burn_in)
        loss = tf.reduce_mean(loss * args.importance_weights)
      grads = tape.gradient(loss, agent.trainable_variables)
      gradient_norm_before_clip = tf.linalg.global_norm(grads)
      if FLAGS.clip_norm:
        grads, _ = tf.clip_by_global_norm(
            grads, FLAGS.clip_norm, use_norm=gradient_norm_before_clip)

      for t, g in zip(temp_grads, grads):
        t.assign(g)

      return loss, priorities, args.indices, gradient_norm_before_clip

    loss, priorities, indices, gradient_norm_before_clip = (
        training_strategy.experimental_run_v2(compute_gradients, (data,)))
    loss = training_strategy.experimental_local_results(loss)[0]

    def apply_gradients(loss):
      optimizer.apply_gradients(zip(temp_grads, agent.trainable_variables))
      return loss

    loss = strategy.experimental_run_v2(apply_gradients, (loss,))

    # Convert PerReplica results to plain tensors.
    if not isinstance(priorities, tf.Tensor):
      priorities = tf.reshape(tf.stack(priorities.values), [-1])
      indices = tf.reshape(tf.stack(indices.values), [-1])
      gradient_norm_before_clip = tf.reshape(
          tf.stack(gradient_norm_before_clip.values), [-1])
      gradient_norm_before_clip = tf.reduce_max(gradient_norm_before_clip)

    return loss, priorities, indices, gradient_norm_before_clip

  agent_output_specs = tf.nest.map_structure(
      lambda t: tf.TensorSpec(t.shape[1:], t.dtype), initial_agent_output)
  # Logging.
  summary_writer = tf.summary.create_file_writer(
      FLAGS.logdir, flush_millis=20000, max_queue=1000)

  # Set up checkpointing and restore checkpoint.
  ckpt = tf.train.Checkpoint(
      agent=agent, target_agent=target_agent, optimizer=optimizer)
  manager = tf.train.CheckpointManager(
      ckpt, FLAGS.logdir, max_to_keep=1, keep_checkpoint_every_n_hours=6)
  last_ckpt_time = 0  # Force checkpointing of the initial model.
  if manager.latest_checkpoint:
    logging.info('Restoring checkpoint: %s', manager.latest_checkpoint)
    ckpt.restore(manager.latest_checkpoint).assert_consumed()
    last_ckpt_time = time.time()

  server = grpc.Server([FLAGS.server_address])

  # Buffer of incomplete unrolls. Filled during inference with new transitions.
  # This only contains data from training actors.
  store = utils.UnrollStore(
      get_num_training_actors(), FLAGS.unroll_length,
      (action_specs, env_output_specs, agent_output_specs),
      num_overlapping_steps=FLAGS.burn_in)
  actor_run_ids = utils.Aggregator(FLAGS.num_actors,
                                   tf.TensorSpec([], tf.int64, 'run_ids'))
  info_specs = (
      tf.TensorSpec([], tf.int64, 'episode_num_frames'),
      tf.TensorSpec([], tf.float32, 'episode_returns'),
      tf.TensorSpec([], tf.float32, 'episode_raw_returns'),
  )
  actor_infos = utils.Aggregator(FLAGS.num_actors, info_specs)

  # First agent state in an unroll.
  first_agent_states = utils.Aggregator(FLAGS.num_actors, agent_state_specs)

  # Current agent state and action.
  agent_states = utils.Aggregator(FLAGS.num_actors, agent_state_specs)
  actions = utils.Aggregator(FLAGS.num_actors, action_specs)

  unroll_specs = Unroll(agent_state_specs,
                        tf.TensorSpec([], tf.float32, 'priority'),
                        *store.unroll_specs)
  # Queue of complete unrolls. Filled by the inference threads, and consumed by
  # the tf.data.Dataset thread.
  unroll_queue = utils.StructuredFIFOQueue(
      FLAGS.unroll_queue_max_size, unroll_specs)
  episode_info_specs = EpisodeInfo(*(
      info_specs + (tf.TensorSpec([], tf.int32, 'actor_ids'),)))
  info_queue = utils.StructuredFIFOQueue(-1, episode_info_specs)

  replay_buffer = utils.PrioritizedReplay(FLAGS.replay_buffer_size,
                                          unroll_specs,
                                          FLAGS.importance_sampling_exponent)

  def add_batch_size(ts):
    return tf.TensorSpec([FLAGS.inference_batch_size] + list(ts.shape),
                         ts.dtype, ts.name)
  inference_iteration = tf.Variable(-1)
  inference_specs = (
      tf.TensorSpec([], tf.int32, 'actor_id'),
      tf.TensorSpec([], tf.int64, 'run_id'),
      env_output_specs,
      tf.TensorSpec([], tf.float32, 'raw_reward'),
  )
  inference_specs = tf.nest.map_structure(add_batch_size, inference_specs)
  @tf.function(input_signature=inference_specs)
  def inference(actor_ids, run_ids, env_outputs, raw_rewards):
    """Agent inference.

    This evaluates the agent policy on the provided environment data (reward,
    done, observation), and stores the appropriate data to feed the main
    training loop.

    Args:
      actor_ids: <int32>[inference_batch_size], the actor task IDs (in range
        [0, num_tasks)).
      run_ids: <int64>[inference_batch_size], the actor run IDs. Actors
        generate a random int64 run ID at startup, so this can be used to
        detect actor jobs that restarted.
      env_outputs: Follows env_output_specs, but with inference_batch_size
        added as the first dimension. These are the actual environment outputs
        (reward, done, observation).
      raw_rewards: <float32>[inference_batch_size], representing the raw reward
        of each step.

    Returns:
      A tensor <int32>[inference_batch_size] with one action for each actor.
    """
    # Reset the actors that had their first run or crashed.
    previous_run_ids = actor_run_ids.read(actor_ids)
    actor_run_ids.replace(actor_ids, run_ids)
    reset_indices = tf.where(tf.not_equal(previous_run_ids, run_ids))[:, 0]
    actors_needing_reset = tf.gather(actor_ids, reset_indices)
    if tf.not_equal(tf.shape(actors_needing_reset)[0], 0):
      tf.print('Actors needing reset:', actors_needing_reset)
    actor_infos.reset(actors_needing_reset)
    store.reset(tf.gather(
        actors_needing_reset,
        tf.where(is_training_actor(actors_needing_reset))[:, 0]))
    initial_agent_states = agent.initial_state(
        tf.shape(actors_needing_reset)[0])
    first_agent_states.replace(actors_needing_reset, initial_agent_states)
    agent_states.replace(actors_needing_reset, initial_agent_states)
    actions.reset(actors_needing_reset)

    # Update steps and return.
    actor_infos.add(actor_ids, (0, env_outputs.reward, raw_rewards))
    done_ids = tf.gather(actor_ids, tf.where(env_outputs.done)[:, 0])
    done_episodes_info = actor_infos.read(done_ids)
    info_queue.enqueue_many(EpisodeInfo(*(done_episodes_info + (done_ids,))))
    actor_infos.reset(done_ids)
    actor_infos.add(actor_ids, (FLAGS.num_action_repeats, 0., 0.))

    # Inference.
    prev_actions = actions.read(actor_ids)
    input_ = encode((prev_actions, env_outputs))
    prev_agent_states = agent_states.read(actor_ids)
    def make_inference_fn(inference_device):
      def device_specific_inference_fn():
        with tf.device(inference_device):
          @tf.function
          def agent_inference(*args):
            return agent(*decode(args))

          return agent_inference(input_, prev_agent_states)

      return device_specific_inference_fn

    # Distribute the inference calls among the inference cores.
    branch_index = inference_iteration.assign_add(1) % len(inference_devices)
    agent_outputs, curr_agent_states = tf.switch_case(branch_index, {
        i: make_inference_fn(inference_device)
        for i, inference_device in enumerate(inference_devices)
    })

    agent_outputs = agent_outputs._replace(
        action=apply_epsilon_greedy(
            agent_outputs.action, actor_ids,
            get_num_training_actors(),
            FLAGS.num_eval_actors, FLAGS.eval_epsilon, num_actions))

    # Append the latest outputs to the unroll, only for experience coming from
    # training actors (IDs < num_training_actors), and insert completed unrolls
    # in queue.
    # <int64>[num_training_actors]
    training_indices = tf.where(is_training_actor(actor_ids))[:, 0]
    training_actor_ids = tf.gather(actor_ids, training_indices)
    training_prev_actions, training_env_outputs, training_agent_outputs = (
        tf.nest.map_structure(lambda s: tf.gather(s, training_indices),
                              (prev_actions, env_outputs, agent_outputs)))

    append_to_store = (
        training_prev_actions, training_env_outputs, training_agent_outputs)
    completed_ids, completed_unrolls = store.append(
        training_actor_ids, append_to_store)
    _, unrolled_env_outputs, unrolled_agent_outputs = completed_unrolls
    unrolled_agent_states = first_agent_states.read(completed_ids)

    # Only use the suffix of the unrolls that is actually used for training. The
    # prefix is only used for burn-in of agent state at training time.
    _, agent_outputs_suffix = utils.split_structure(
        utils.make_time_major(unrolled_agent_outputs), FLAGS.burn_in)
    _, env_outputs_suffix = utils.split_structure(
        utils.make_time_major(unrolled_env_outputs), FLAGS.burn_in)
    _, initial_priorities = compute_loss_and_priorities_from_agent_outputs(
        # We don't use the outputs from a separate target network for computing
        # the initial priorities.
        agent_outputs_suffix,
        agent_outputs_suffix,
        env_outputs_suffix,
        agent_outputs_suffix,
        gamma=FLAGS.discounting)

    unrolls = Unroll(unrolled_agent_states, initial_priorities,
                     *completed_unrolls)
    unroll_queue.enqueue_many(unrolls)
    first_agent_states.replace(completed_ids,
                               agent_states.read(completed_ids))

    # Update current state.
    agent_states.replace(actor_ids, curr_agent_states)
    actions.replace(actor_ids, agent_outputs.action)

    # Return environment actions to actors.
    return agent_outputs.action

  with strategy.scope():
    server.bind(inference, batched=True)
  server.start()

  # Execute learning and track performance.
  with summary_writer.as_default(), \
    concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
    log_future = executor.submit(lambda: None)  # No-op future.
    tf.summary.experimental.set_step(iterations * iter_frame_ratio)
    dataset = create_dataset(unroll_queue, replay_buffer, training_strategy,
                             FLAGS.batch_size, FLAGS.priority_exponent, encode)
    it = iter(dataset)

    last_num_env_frames = iterations * iter_frame_ratio
    last_log_time = time.time()
    max_gradient_norm_before_clip = 0.
    while iterations < final_iteration:
      num_env_frames = iterations * iter_frame_ratio
      tf.summary.experimental.set_step(num_env_frames)

      if iterations.numpy() % FLAGS.update_target_every_n_step == 0:
        update_target_agent()

      # Save checkpoint.
      current_time = time.time()
      if current_time - last_ckpt_time >= FLAGS.save_checkpoint_secs:
        manager.save()
        last_ckpt_time = current_time

      def log(num_env_frames):
        """Logs actor summaries."""
        summary_writer.set_as_default()
        tf.summary.experimental.set_step(num_env_frames)
        episode_info = info_queue.dequeue_many(info_queue.size())
        for n, r, _, actor_id in zip(*episode_info):
          is_training = is_training_actor(actor_id)
          logging.info(
              'Return: %f Frames: %i Actor id: %i (%s) Iteration: %i',
              r, n, actor_id,
              'training' if is_training else 'eval',
              iterations.numpy())
          if not is_training:
            tf.summary.scalar('eval/episode_return', r)
            tf.summary.scalar('eval/episode_frames', n)
      log_future.result()  # Raise exception if any occurred in logging.
      log_future = executor.submit(log, num_env_frames)

      _, priorities, indices, gradient_norm = minimize(it)

      replay_buffer.update_priorities(indices, priorities)
      # Max of gradient norms (before clipping) since last tf.summary export.
      max_gradient_norm_before_clip = max(gradient_norm.numpy(),
                                          max_gradient_norm_before_clip)
      if current_time - last_log_time >= 120:
        df = tf.cast(num_env_frames - last_num_env_frames, tf.float32)
        dt = time.time() - last_log_time
        tf.summary.scalar('num_environment_frames/sec (actors)', df / dt)
        tf.summary.scalar('num_environment_frames/sec (learner)',
                          df / dt * FLAGS.replay_ratio)

        tf.summary.scalar('learning_rate', learning_rate_fn(iterations))
        tf.summary.scalar('replay_buffer_num_inserted',
                          replay_buffer.num_inserted)
        tf.summary.scalar('unroll_queue_size', unroll_queue.size())

        last_num_env_frames, last_log_time = num_env_frames, time.time()
        tf.summary.histogram('updated_priorities', priorities)
        tf.summary.scalar('max_gradient_norm_before_clip',
                          max_gradient_norm_before_clip)
        max_gradient_norm_before_clip = 0.

  manager.save()
  server.shutdown()
  unroll_queue.close()
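
Example No. 5 packs replay samples into named tuples whose definitions are not shown here. Minimal definitions consistent with how the fields are constructed and read above; the field order of SampledUnrolls is an assumption, since only the field names are implied by the code (note that Example No. 4 uses a shorter Unroll without the priority field):

import collections

# Matches Unroll(agent_state, priority, *store.unroll_specs), where
# store.unroll_specs is (prev_actions, env_outputs, agent_outputs).
Unroll = collections.namedtuple(
    'Unroll', 'agent_state priority prev_actions env_outputs agent_outputs')

# compute_gradients reads args.unrolls, args.indices and
# args.importance_weights; the field order here is a guess.
SampledUnrolls = collections.namedtuple(
    'SampledUnrolls', 'unrolls indices importance_weights')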