Example #1
0
  def test_logger(self):
    """ProgressLogger exposes the values of the latest logged step."""
    progress_logger = utils.ProgressLogger()
    progress_logger.start()
    # Exercise _log() before any value has been logged.
    progress_logger._log()

    @tf.function(input_signature=(tf.TensorSpec([], tf.int32, 'value'),))
    def log_pair(value):
      # One log session records `value` and `value + 1`, then ends the step.
      log_session = progress_logger.log_session()
      progress_logger.log(log_session, 'value_1', value)
      progress_logger.log(log_session, 'value_2', value + 1)
      progress_logger.step_end(log_session)

    first, second = 10, 15

    log_pair(tf.constant(first))
    progress_logger._log()
    self.assertAllEqual(progress_logger.ready_values.read_value(),
                        tf.constant([first, first + 1]))

    # ready_values is expected to reflect the newest step even before the
    # next _log() call.
    log_pair(tf.constant(second))
    self.assertAllEqual(progress_logger.ready_values.read_value(),
                        tf.constant([second, second + 1]))
    progress_logger._log()
    progress_logger.shutdown()
Example #2
0
def learner_loop(create_env_fn, create_agent_fn, create_optimizer_fn):
    """Main learner loop.

    Builds a training agent and a target agent, the optimizer, a replay
    buffer and the gRPC inference server, then trains until the configured
    number of environment frames has been consumed, checkpointing along the
    way.

    Args:
      create_env_fn: Callable that must return a newly created environment. It
        is called here as `create_env_fn(0, FLAGS)`: an arbitrary task ID of 0
        plus the flags object. The returned environment should follow GYM's
        API. It is only used for inferring tensor shapes and, when
        `FLAGS.her_window_length` is set, for its `compute_reward` method,
        which is handed to the hindsight experience replay buffer. This
        environment will not be used to generate experience.
      create_agent_fn: Function that must create a new tf.Module with the
        neural network that outputs actions and new agent state given the
        environment observations and previous agent state. It is called twice
        (training agent and target agent) with the environment action space,
        the observation space and a parametric distribution over actions. The
        module must provide `initial_state`, `get_action`, `get_V` and `get_Q`
        as used below, and is called directly for inference.
      create_optimizer_fn: Function that takes the final iteration as argument
        and must return a tf.keras.optimizers.Optimizer and a
        tf.keras.optimizers.schedules.LearningRateSchedule.
    """
    logging.info('Starting learner loop')
    validate_config()
    settings = utils.init_learner(FLAGS.num_training_tpus)
    strategy, inference_devices, training_strategy, encode, decode = settings
    # The environment is only instantiated to derive tensor specs (and, for
    # HER, its reward function); the learner never steps it.
    env = create_env_fn(0, FLAGS)
    parametric_action_distribution = get_parametric_distribution_for_action_space(
        env.action_space)
    env_output_specs = utils.EnvOutput(
        tf.TensorSpec([], tf.float32, 'reward'),
        tf.TensorSpec([], tf.bool, 'done'),
        # If the observation space stores sub-spaces in a `spaces` attribute
        # (e.g. a gym Dict/Tuple space) build one spec per sub-space,
        # otherwise use the space itself.
        tf.nest.map_structure(
            lambda s: tf.TensorSpec(s.shape, s.dtype, 'observation'),
            env.observation_space.__dict__.get('spaces',
                                               env.observation_space)),
        tf.TensorSpec([], tf.bool, 'abandoned'),
        tf.TensorSpec([], tf.int32, 'episode_step'),
    )
    action_specs = tf.TensorSpec(env.action_space.shape,
                                 env.action_space.dtype, 'action')

    # Initialize agent and variables.
    # `target_agent` is a second copy of the network whose weights trail
    # `agent` through Polyak averaging (see update_target_agent below).
    agent = create_agent_fn(env.action_space, env.observation_space,
                            parametric_action_distribution)
    target_agent = create_agent_fn(env.action_space, env.observation_space,
                                   parametric_action_distribution)
    initial_agent_state = agent.initial_state(1)
    # Per-environment agent state specs: drop the leading dimension of the
    # size-1 initial state (presumably the batch dimension).
    agent_state_specs = tf.nest.map_structure(
        lambda t: tf.TensorSpec(t.shape[1:], t.dtype), initial_agent_state)
    agent_input_specs = (action_specs, env_output_specs)

    # All-zero dummy inputs used only to trace the agents and create their
    # variables. `input_` has two leading size-1 dimensions (presumably time
    # and batch); `input_no_time` drops the first one.
    input_ = tf.nest.map_structure(
        lambda s: tf.zeros([1, 1] + list(s.shape), s.dtype), agent_input_specs)
    input_no_time = tf.nest.map_structure(lambda t: t[0], input_)

    input_ = encode(input_ + (initial_agent_state, ))
    input_no_time = encode(input_no_time + (initial_agent_state, ))

    with strategy.scope():
        # Initialize variables
        def initialize_agent_variables(agent):
            # Attach a learnable entropy coefficient unless the agent already
            # defines `entropy_cost`. The coefficient is
            # exp(mul * entropy_cost_param), so the variable starts at
            # log(FLAGS.entropy_cost) / mul and the clip below bounds the
            # coefficient to [exp(-20), exp(20)].
            if not hasattr(agent, 'entropy_cost'):
                mul = FLAGS.entropy_cost_adjustment_speed
                agent.entropy_cost_param = tf.Variable(
                    tf.math.log(FLAGS.entropy_cost) / mul,
                    # Without the constraint, the param gradient may get rounded to 0
                    # for very small values.
                    constraint=lambda v: tf.clip_by_value(
                        v, -20 / mul, 20 / mul),
                    trainable=True,
                    dtype=tf.float32)
                agent.entropy_cost = lambda: tf.exp(mul * agent.
                                                    entropy_cost_param)

            # Trace the action, V and Q heads once on the dummy inputs so
            # that the agent's variables are created inside the strategy
            # scope.
            @tf.function
            def create_variables():
                return [
                    agent.get_action(*decode(input_no_time)),
                    agent.get_V(*decode(input_)),
                    agent.get_Q(*decode(input_), action=decode(input_[0]))
                ]

            create_variables()

        initialize_agent_variables(agent)
        initialize_agent_variables(target_agent)

        # Target network update
        @tf.function
        def update_target_agent(polyak):
            """Synchronizes training and target agent variables.

            Every target variable becomes
            `polyak * target + (1 - polyak) * source`, so `polyak=0` copies the
            training weights and values close to 1 move the target slowly.
            Variables are paired by position, which presumably relies on both
            agents (built by the same factory) listing variables in the same
            order.
            """
            variables = agent.variables
            target_variables = target_agent.variables
            assert len(target_variables) == len(variables), (
                'Mismatch in number of net tensors: {} != {}'.format(
                    len(target_variables), len(variables)))
            for target_var, source_var in zip(target_variables, variables):
                target_var.assign(polyak * target_var +
                                  (1. - polyak) * source_var)

        update_target_agent(polyak=0.)  # copy weights

        # Create optimizer.
        # Environment frames consumed per training iteration: replay insertion
        # batch size (over all replicas) x window length (HER window if set,
        # otherwise the unroll length) x action repeats.
        iter_frame_ratio = (
            get_replay_insertion_batch_size(per_replica=False) *
            (FLAGS.her_window_length or FLAGS.unroll_length) *
            FLAGS.num_action_repeats)
        final_iteration = int(
            math.ceil(FLAGS.total_environment_frames / iter_frame_ratio))
        optimizer, learning_rate_fn = create_optimizer_fn(final_iteration)

        iterations = optimizer.iterations
        # Create the optimizer's hyperparameter and slot variables up front,
        # inside the strategy scope (private Keras optimizer APIs).
        optimizer._create_hypers()
        optimizer._create_slots(agent.trainable_variables)

        # ON_READ causes the replicated variable to act as independent variables for
        # each replica.
        # These buffers carry each replica's gradients from compute_gradients
        # to apply_gradients in minimize().
        temp_grads = [
            tf.Variable(tf.zeros_like(v),
                        trainable=False,
                        synchronization=tf.VariableSynchronization.ON_READ)
            for v in agent.trainable_variables
        ]

    @tf.function
    def minimize(iterator):
        """Runs one training iteration on the next batch from `iterator`."""
        data = next(iterator)

        def compute_gradients(args):
            # `unroll_specs` and `logger` are assigned further down in
            # learner_loop; the closure resolves them when minimize() is first
            # traced, which happens in the training loop after both exist.
            args = tf.nest.pack_sequence_as(unroll_specs, decode(args, data))
            with tf.GradientTape() as tape:
                loss, logs = compute_loss(logger,
                                          parametric_action_distribution,
                                          agent, target_agent, *args)
            grads = tape.gradient(loss, agent.trainable_variables)
            for t, g in zip(temp_grads, grads):
                t.assign(g)
            return loss, logs

        loss, logs = training_strategy.run(compute_gradients, (data, ))
        loss = training_strategy.experimental_local_results(loss)[0]

        def apply_gradients(_):
            optimizer.apply_gradients(
                zip(temp_grads, agent.trainable_variables))

        # Gradients are computed under `training_strategy` but applied under
        # `strategy`, in whose scope the agent and optimizer variables were
        # created.
        strategy.run(apply_gradients, (loss, ))

        # Optional agent hook; a missing hook is only logged.
        getattr(
            agent, 'end_of_training_step_callback',
            lambda: logging.info('end_of_training_step_callback not found'))()

        logger.step_end(logs, training_strategy, iter_frame_ratio)

    # Setup checkpointing and restore checkpoint.
    # An explicit FLAGS.init_checkpoint is loaded first. A checkpoint already
    # present in FLAGS.logdir (presumably from an earlier run of this job) is
    # restored afterwards and therefore takes precedence.
    ckpt = tf.train.Checkpoint(agent=agent,
                               target_agent=target_agent,
                               optimizer=optimizer)
    if FLAGS.init_checkpoint is not None:
        tf.print('Loading initial checkpoint from %s...' %
                 FLAGS.init_checkpoint)
        ckpt.restore(FLAGS.init_checkpoint).assert_consumed()
    manager = tf.train.CheckpointManager(ckpt,
                                         FLAGS.logdir,
                                         max_to_keep=1,
                                         keep_checkpoint_every_n_hours=6)
    last_ckpt_time = 0  # Force checkpointing of the initial model.
    if manager.latest_checkpoint:
        logging.info('Restoring checkpoint: %s', manager.latest_checkpoint)
        ckpt.restore(manager.latest_checkpoint).assert_consumed()
        last_ckpt_time = time.time()

    # Logging.
    # `iterations` may have been restored above, so the logger's step counter
    # (in environment frames) resumes from the restored iteration.
    summary_writer = tf.summary.create_file_writer(FLAGS.logdir,
                                                   flush_millis=20000,
                                                   max_queue=1000)
    logger = utils.ProgressLogger(summary_writer=summary_writer,
                                  starting_step=iterations * iter_frame_ratio)

    server = grpc.Server([FLAGS.server_address])

    # Accumulates per-environment (prev_action, env_output, action) steps into
    # unrolls of length her_window_length (if set) or unroll_length.
    store = utils.UnrollStore(FLAGS.num_envs, FLAGS.her_window_length
                              or FLAGS.unroll_length,
                              (action_specs, env_output_specs, action_specs))
    # Last run id seen per environment; a change marks an environment that
    # must be reset (first run or crash).
    env_run_ids = utils.Aggregator(FLAGS.num_envs,
                                   tf.TensorSpec([], tf.int64, 'run_ids'))
    # Running per-episode statistics for each environment.
    info_specs = (
        tf.TensorSpec([], tf.int64, 'episode_num_frames'),
        tf.TensorSpec([], tf.float32, 'episode_returns'),
        tf.TensorSpec([], tf.float32, 'episode_raw_returns'),
    )
    env_infos = utils.Aggregator(FLAGS.num_envs, info_specs, 'env_infos')

    # First agent state in an unroll.
    first_agent_states = utils.Aggregator(FLAGS.num_envs, agent_state_specs,
                                          'first_agent_states')

    # Current agent state and action.
    agent_states = utils.Aggregator(FLAGS.num_envs, agent_state_specs,
                                    'agent_states')
    actions = utils.Aggregator(FLAGS.num_envs, action_specs, 'actions')

    # Completed unrolls flow from inference() into unroll_queue; finished
    # episode statistics go to info_queue (capacity -1, presumably unbounded)
    # and are consumed by additional_logs().
    unroll_specs = Unroll(agent_state_specs, *store.unroll_specs)
    unroll_queue = utils.StructuredFIFOQueue(FLAGS.unroll_queue_max_size,
                                             unroll_specs)
    info_queue = utils.StructuredFIFOQueue(-1, info_specs)

    # Hindsight experience replay when a HER window is configured, otherwise
    # prioritized replay; importance sampling is disabled (exponent 0) in
    # both cases.
    if FLAGS.her_window_length:
        replay_buffer = utils.HindsightExperienceReplay(
            FLAGS.replay_buffer_size,
            unroll_specs,
            compute_reward_fn=env.compute_reward,
            unroll_length=FLAGS.unroll_length,
            importance_sampling_exponent=0.,
            substitution_probability=FLAGS.her_substitution_probability)
    else:
        replay_buffer = utils.PrioritizedReplay(
            FLAGS.replay_buffer_size,
            unroll_specs,
            importance_sampling_exponent=0.)

    def add_batch_size(ts):
        # Prepend the fixed inference batch dimension to a TensorSpec.
        return tf.TensorSpec([FLAGS.inference_batch_size] + list(ts.shape),
                             ts.dtype, ts.name)

    # Call counter used to round-robin inference over inference_devices.
    inference_iteration = tf.Variable(-1, dtype=tf.int64)
    inference_specs = (
        tf.TensorSpec([], tf.int32, 'env_id'),
        tf.TensorSpec([], tf.int64, 'run_id'),
        env_output_specs,
        tf.TensorSpec([], tf.float32, 'raw_reward'),
    )
    inference_specs = tf.nest.map_structure(add_batch_size, inference_specs)

    @tf.function(input_signature=inference_specs)
    def inference(env_ids, run_ids, env_outputs, raw_rewards):
        """Served over gRPC: returns the next action for a batch of envs.

        Also resets restarted environments, records episode statistics and
        pushes completed unrolls to `unroll_queue`.
        """
        # Reset the environment that had their first run or crashed.
        previous_run_ids = env_run_ids.read(env_ids)
        env_run_ids.replace(env_ids, run_ids)
        reset_indices = tf.where(tf.not_equal(previous_run_ids, run_ids))[:, 0]
        envs_needing_reset = tf.gather(env_ids, reset_indices)
        if tf.not_equal(tf.shape(envs_needing_reset)[0], 0):
            tf.print('Environment ids needing reset:', envs_needing_reset)
        env_infos.reset(envs_needing_reset)
        store.reset(envs_needing_reset)
        initial_agent_states = agent.initial_state(
            tf.shape(envs_needing_reset)[0])
        first_agent_states.replace(envs_needing_reset, initial_agent_states)
        agent_states.replace(envs_needing_reset, initial_agent_states)
        actions.reset(envs_needing_reset)

        tf.debugging.assert_non_positive(
            tf.cast(env_outputs.abandoned, tf.int32),
            'Abandoned done states are not supported in SAC.')

        # Update steps and return.
        # Add this step's rewards; emit and reset the statistics of envs whose
        # episode just ended; then count num_action_repeats frames for every
        # env (including those just reset).
        env_infos.add(env_ids, (0, env_outputs.reward, raw_rewards))
        done_ids = tf.gather(env_ids, tf.where(env_outputs.done)[:, 0])
        info_queue.enqueue_many(env_infos.read(done_ids))
        env_infos.reset(done_ids)
        env_infos.add(env_ids, (FLAGS.num_action_repeats, 0., 0.))

        # Inference.
        prev_actions = actions.read(env_ids)
        input_ = encode((prev_actions, env_outputs))
        prev_agent_states = agent_states.read(env_ids)

        def make_inference_fn(inference_device):
            # Returns a zero-argument branch for tf.switch_case that runs the
            # agent on `inference_device`.
            def device_specific_inference_fn():
                with tf.device(inference_device):

                    @tf.function
                    def agent_inference(*args):
                        return agent(*decode(args), is_training=False)

                    return agent_inference(*input_, prev_agent_states)

            return device_specific_inference_fn

        # Distribute the inference calls among the inference cores.
        branch_index = tf.cast(
            inference_iteration.assign_add(1) % len(inference_devices),
            tf.int32)
        agent_actions, curr_agent_states = tf.switch_case(
            branch_index, {
                i: make_inference_fn(inference_device)
                for i, inference_device in enumerate(inference_devices)
            })

        # Append the latest outputs to the unroll and insert completed unrolls in
        # queue.
        completed_ids, unrolls = store.append(
            env_ids, (prev_actions, env_outputs, agent_actions))
        unrolls = Unroll(first_agent_states.read(completed_ids), *unrolls)
        unroll_queue.enqueue_many(unrolls)
        first_agent_states.replace(completed_ids,
                                   agent_states.read(completed_ids))

        # Update current state.
        agent_states.replace(env_ids, curr_agent_states)
        actions.replace(env_ids, agent_actions)

        # Return environment actions to environments.
        return agent_actions

    with strategy.scope():
        server.bind(inference)
    server.start()

    dataset = create_dataset(unroll_queue, replay_buffer, training_strategy,
                             FLAGS.batch_size, encode)
    it = iter(dataset)

    def additional_logs():
        # Extra summaries written through the progress logger: learning rate,
        # replay inserts and episode statistics. Episode stats are averaged in
        # groups of FLAGS.log_episode_frequency; leftover episodes stay queued
        # until a full group is available.
        tf.summary.scalar('learning_rate', learning_rate_fn(iterations))
        tf.summary.scalar('buffer/unrolls_inserted',
                          replay_buffer.num_inserted)
        # log data from info_queue
        n_episodes = info_queue.size()
        n_episodes -= n_episodes % FLAGS.log_episode_frequency
        if tf.not_equal(n_episodes, 0):
            episode_stats = info_queue.dequeue_many(n_episodes)
            episode_keys = [
                'episode_num_frames', 'episode_return', 'episode_raw_return'
            ]
            for key, values in zip(episode_keys, episode_stats):
                for value in tf.split(
                        values,
                        values.shape[0] // FLAGS.log_episode_frequency):
                    tf.summary.scalar(key, tf.reduce_mean(value))

            for (frames, ep_return, raw_return) in zip(*episode_stats):
                logging.info('Return: %f Raw return: %f Frames: %i', ep_return,
                             raw_return, frames)

    logger.start(additional_logs)
    # Execute learning.
    while iterations < final_iteration:
        # Soft-update the target agent every update_target_every_n_step
        # iterations (including iteration 0).
        if iterations.numpy() % FLAGS.update_target_every_n_step == 0:
            update_target_agent(FLAGS.polyak)
        # Save checkpoint.
        current_time = time.time()
        if current_time - last_ckpt_time >= FLAGS.save_checkpoint_secs:
            manager.save()
            # Apart from checkpointing, we also save the full model (including
            # the graph). This way we can load it after the code/parameters changed.
            tf.saved_model.save(agent, os.path.join(FLAGS.logdir,
                                                    'saved_model'))
            last_ckpt_time = current_time
        minimize(it)
    # Final checkpoint and SavedModel, then stop serving.
    logger.shutdown()
    manager.save()
    tf.saved_model.save(agent, os.path.join(FLAGS.logdir, 'saved_model'))
    server.shutdown()
    unroll_queue.close()
Example #3
0
def learner_loop(create_env_fn, create_agent_fn, create_optimizer_fn):
    """Main learner loop.

  Args:
    create_env_fn: Callable that must return a newly created environment. The
      callable takes the task ID as argument - an arbitrary task ID of 0 will be
      passed by the learner. The returned environment should follow GYM's API.
      It is only used for infering tensor shapes. This environment will not be
      used to generate experience.
    create_agent_fn: Function that must create a new tf.Module with the neural
      network that outputs actions and new agent state given the environment
      observations and previous agent state. See dmlab.agents.ImpalaDeep for an
      example. The factory function takes as input the environment action and
      observation spaces and a parametric distribution over actions.
    create_optimizer_fn: Function that takes the final iteration as argument
      and must return a tf.keras.optimizers.Optimizer and a
      tf.keras.optimizers.schedules.LearningRateSchedule.
  """
    logging.info('Starting learner loop')
    validate_config()
    settings = utils.init_learner(FLAGS.num_training_tpus)
    strategy, inference_devices, training_strategy, encode, decode = settings
    env = create_env_fn(0)
    parametric_action_distribution = get_parametric_distribution_for_action_space(
        env.action_space)
    env_output_specs = utils.EnvOutput(
        tf.TensorSpec([], tf.float32, 'reward'),
        tf.TensorSpec([], tf.bool, 'done'),
        tf.TensorSpec(env.observation_space.shape, env.observation_space.dtype,
                      'observation'),
    )
    action_specs = tf.TensorSpec(env.action_space.shape,
                                 env.action_space.dtype, 'action')
    agent_input_specs = (action_specs, env_output_specs)

    # Initialize agent and variables.
    agent = create_agent_fn(env.action_space, env.observation_space,
                            parametric_action_distribution)
    initial_agent_state = agent.initial_state(1)
    agent_state_specs = tf.nest.map_structure(
        lambda t: tf.TensorSpec(t.shape[1:], t.dtype), initial_agent_state)
    input_ = tf.nest.map_structure(
        lambda s: tf.zeros([1] + list(s.shape), s.dtype), agent_input_specs)
    input_ = encode(input_)

    with strategy.scope():

        @tf.function
        def create_variables(*args):
            return agent.get_action(*decode(args))

        initial_agent_output, _ = create_variables(*input_,
                                                   initial_agent_state)

        if not hasattr(agent, 'entropy_cost'):
            mul = FLAGS.entropy_cost_adjustment_speed
            agent.entropy_cost_param = tf.Variable(
                tf.math.log(FLAGS.entropy_cost) / mul,
                # Without the constraint, the param gradient may get rounded to 0
                # for very small values.
                constraint=lambda v: tf.clip_by_value(v, -20 / mul, 20 / mul),
                trainable=True,
                dtype=tf.float32)
            agent.entropy_cost = lambda: tf.exp(mul * agent.entropy_cost_param)
        # Create optimizer.
        iter_frame_ratio = (FLAGS.batch_size * FLAGS.unroll_length *
                            FLAGS.num_action_repeats)
        final_iteration = int(
            math.ceil(FLAGS.total_environment_frames / iter_frame_ratio))
        optimizer, learning_rate_fn = create_optimizer_fn(final_iteration)

        iterations = optimizer.iterations
        optimizer._create_hypers()
        optimizer._create_slots(agent.trainable_variables)

        # ON_READ causes the replicated variable to act as independent variables for
        # each replica.
        temp_grads = [
            tf.Variable(tf.zeros_like(v),
                        trainable=False,
                        synchronization=tf.VariableSynchronization.ON_READ)
            for v in agent.trainable_variables
        ]

    @tf.function
    def minimize(iterator):
        data = next(iterator)

        def compute_gradients(args):
            args = tf.nest.pack_sequence_as(unroll_specs, decode(args, data))
            with tf.GradientTape() as tape:
                loss, logs = compute_loss(logger,
                                          parametric_action_distribution,
                                          agent, *args)
            grads = tape.gradient(loss, agent.trainable_variables)
            for t, g in zip(temp_grads, grads):
                t.assign(g)

            with logger.summary_writer.as_default(), \
              tf.compat.v2.summary.record_if(tf.not_equal(iterations % 1000, 0)):
                for g, v in zip(temp_grads, agent.trainable_variables):
                    tf.summary.histogram(v.name, v, step=iterations)
                    tf.summary.histogram(v.name + "_grad", g, step=iterations)

            return loss, logs

        loss, logs = training_strategy.experimental_run_v2(
            compute_gradients, (data, ))
        loss = training_strategy.experimental_local_results(loss)[0]

        def apply_gradients(_):
            # clip_grads, _ = tf.clip_by_global_norm(temp_grads, 0.5)
            # clip_grads  = tf.clip_by_value(temp_grads, -0.2, 0.2)
            optimizer.apply_gradients(
                zip(temp_grads, agent.trainable_variables))

        strategy.experimental_run_v2(apply_gradients, (loss, ))

        try:
            agent.end_of_training_step_callback()
        except AttributeError:
            logging.info('end_of_episode_callback() not found')
        logger.step_end(logs, training_strategy, iter_frame_ratio)

    agent_output_specs = tf.nest.map_structure(
        lambda t: tf.TensorSpec(t.shape[1:], t.dtype), initial_agent_output)
    # Logging.
    summary_writer = tf.summary.create_file_writer(FLAGS.logdir,
                                                   flush_millis=20000,
                                                   max_queue=1000)
    logger = utils.ProgressLogger(summary_writer=summary_writer)

    with summary_writer.as_default():
        tf.summary.text("flags", str(FLAGS.flag_values_dict()), step=0)

    # Setup checkpointing and restore checkpoint.
    ckpt = tf.train.Checkpoint(agent=agent, optimizer=optimizer)
    if FLAGS.init_checkpoint is not None:
        tf.print('Loading initial checkpoint from %s...' %
                 FLAGS.init_checkpoint)
        ckpt.restore(FLAGS.init_checkpoint).assert_consumed()
    manager = tf.train.CheckpointManager(ckpt,
                                         FLAGS.logdir,
                                         max_to_keep=FLAGS.max_to_keep,
                                         keep_checkpoint_every_n_hours=6)
    last_ckpt_time = 0  # Force checkpointing of the initial model.
    if manager.latest_checkpoint:
        logging.info('Restoring checkpoint: %s', manager.latest_checkpoint)
        ckpt.restore(manager.latest_checkpoint).assert_consumed()
        last_ckpt_time = time.time()

    server = grpc.Server([FLAGS.server_address])

    store = utils.UnrollStore(
        FLAGS.num_actors, FLAGS.unroll_length,
        (action_specs, env_output_specs, agent_output_specs))
    actor_run_ids = utils.Aggregator(FLAGS.num_actors,
                                     tf.TensorSpec([], tf.int64, 'run_ids'))
    info_specs = (
        tf.TensorSpec([], tf.int64, 'episode_num_frames'),
        tf.TensorSpec([], tf.float32, 'episode_returns'),
        tf.TensorSpec([], tf.float32, 'episode_raw_returns'),
    )
    actor_infos = utils.Aggregator(FLAGS.num_actors, info_specs, 'actor_infos')

    # First agent state in an unroll.
    first_agent_states = utils.Aggregator(FLAGS.num_actors, agent_state_specs,
                                          'first_agent_states')

    # Current agent state and action.
    agent_states = utils.Aggregator(FLAGS.num_actors, agent_state_specs,
                                    'agent_states')
    actions = utils.Aggregator(FLAGS.num_actors, action_specs, 'actions')

    unroll_specs = Unroll(agent_state_specs, *store.unroll_specs)
    unroll_queue = utils.StructuredFIFOQueue(1, unroll_specs)
    info_queue = utils.StructuredFIFOQueue(-1, info_specs)

    def add_batch_size(ts):
        return tf.TensorSpec([FLAGS.inference_batch_size] + list(ts.shape),
                             ts.dtype, ts.name)

    inference_iteration = tf.Variable(-1)
    inference_specs = (
        tf.TensorSpec([], tf.int32, 'actor_id'),
        tf.TensorSpec([], tf.int64, 'run_id'),
        env_output_specs,
        tf.TensorSpec([], tf.float32, 'raw_reward'),
    )
    inference_specs = tf.nest.map_structure(add_batch_size, inference_specs)

    @tf.function(input_signature=inference_specs)
    def inference(actor_ids, run_ids, env_outputs, raw_rewards):
        # Reset the actors that had their first run or crashed.
        previous_run_ids = actor_run_ids.read(actor_ids)
        actor_run_ids.replace(actor_ids, run_ids)
        reset_indices = tf.where(tf.not_equal(previous_run_ids, run_ids))[:, 0]
        actors_needing_reset = tf.gather(actor_ids, reset_indices)
        if tf.not_equal(tf.shape(actors_needing_reset)[0], 0):
            tf.print('Actor ids needing reset:', actors_needing_reset)
        actor_infos.reset(actors_needing_reset)
        store.reset(actors_needing_reset)
        initial_agent_states = agent.initial_state(
            tf.shape(actors_needing_reset)[0])
        first_agent_states.replace(actors_needing_reset, initial_agent_states)
        agent_states.replace(actors_needing_reset, initial_agent_states)
        actions.reset(actors_needing_reset)

        # Update steps and return.
        actor_infos.add(actor_ids, (0, env_outputs.reward, raw_rewards))
        done_ids = tf.gather(actor_ids, tf.where(env_outputs.done)[:, 0])
        info_queue.enqueue_many(actor_infos.read(done_ids))
        actor_infos.reset(done_ids)
        actor_infos.add(actor_ids, (FLAGS.num_action_repeats, 0., 0.))

        # Inference.
        prev_actions = parametric_action_distribution.postprocess(
            actions.read(actor_ids))
        input_ = encode((prev_actions, env_outputs))
        prev_agent_states = agent_states.read(actor_ids)

        def make_inference_fn(inference_device):
            def device_specific_inference_fn():
                with tf.device(inference_device):

                    @tf.function
                    def agent_inference(*args):
                        return agent(*decode(args),
                                     is_training=False,
                                     postprocess_action=False)

                    return agent_inference(*input_, prev_agent_states)

            return device_specific_inference_fn

        # Distribute the inference calls among the inference cores.
        branch_index = inference_iteration.assign_add(1) % len(
            inference_devices)
        agent_outputs, curr_agent_states = tf.switch_case(
            branch_index, {
                i: make_inference_fn(inference_device)
                for i, inference_device in enumerate(inference_devices)
            })

        # Append the latest outputs to the unroll and insert completed unrolls in
        # queue.
        completed_ids, unrolls = store.append(
            actor_ids, (prev_actions, env_outputs, agent_outputs))
        unrolls = Unroll(first_agent_states.read(completed_ids), *unrolls)
        unroll_queue.enqueue_many(unrolls)
        first_agent_states.replace(completed_ids,
                                   agent_states.read(completed_ids))

        # Update current state.
        agent_states.replace(actor_ids, curr_agent_states)
        actions.replace(actor_ids, agent_outputs.action)

        # Return environment actions to actors.
        return parametric_action_distribution.postprocess(agent_outputs.action)

#########################################################

    def add_batch_size_eval(ts):
        return tf.TensorSpec([1] + list(ts.shape), ts.dtype, ts.name)

    inference_iteration_eval = tf.Variable(-1)
    inference_specs_eval = (
        tf.TensorSpec([], tf.int32, 'actor_id'),
        tf.TensorSpec([], tf.int64, 'run_id'),
        env_output_specs,
        tf.TensorSpec([], tf.float32, 'raw_reward'),
    )
    inference_specs_eval = tf.nest.map_structure(add_batch_size_eval,
                                                 inference_specs_eval)

    @tf.function(input_signature=inference_specs_eval)
    def inference_eval(actor_ids, run_ids, env_outputs, raw_rewards):
        # Reset the actors that had their first run or crashed.
        previous_run_ids = actor_run_ids.read(actor_ids)
        actor_run_ids.replace(actor_ids, run_ids)
        reset_indices = tf.where(tf.not_equal(previous_run_ids, run_ids))[:, 0]
        actors_needing_reset = tf.gather(actor_ids, reset_indices)
        if tf.not_equal(tf.shape(actors_needing_reset)[0], 0):
            tf.print('Actor ids needing reset:', actors_needing_reset)

        initial_agent_states = agent.initial_state(
            tf.shape(actors_needing_reset)[0])
        agent_states.replace(actors_needing_reset, initial_agent_states)
        actions.reset(actors_needing_reset)

        # Inference.
        prev_actions = parametric_action_distribution.postprocess(
            actions.read(actor_ids))
        input_ = encode((prev_actions, env_outputs))
        prev_agent_states = agent_states.read(actor_ids)

        def make_inference_fn(inference_device):
            def device_specific_inference_fn():
                with tf.device(inference_device):

                    @tf.function
                    def agent_inference(*args):
                        return agent(*decode(args),
                                     is_training=False,
                                     postprocess_action=False)

                    return agent_inference(*input_, prev_agent_states)

            return device_specific_inference_fn

        # Distribute the inference calls among the inference cores.
        branch_index = inference_iteration_eval.assign_add(1) % len(
            inference_devices)
        agent_outputs, curr_agent_states = tf.switch_case(
            branch_index, {
                i: make_inference_fn(inference_device)
                for i, inference_device in enumerate(inference_devices)
            })

        # Update current state.
        agent_states.replace(actor_ids, curr_agent_states)
        actions.replace(actor_ids, agent_outputs.action)

        # Return environment actions to actors.
        return parametric_action_distribution.postprocess(agent_outputs.action)


#########################################################

    with strategy.scope():
        server.bind(inference, batched=True)
        server.bind(inference_eval, batched=True)
    server.start()

    def dequeue(ctx):
        """Pull one per-replica training batch off `unroll_queue`.

        Dequeues `ctx.get_per_replica_batch_size(FLAGS.batch_size)` unrolls and
        stacks matching leaves into a batch. It then converts the previous
        actions, env outputs and agent outputs to time-major layout with
        `utils.make_time_major`, runs the env outputs through `encode`, and
        flattens the result.

        Args:
          ctx: The input context that the distribution strategy passes to
            `dataset_fn`.

        Returns:
          The `tf.nest.flatten`-ed batch (a flat list of tensors).
        """
        local_batch = ctx.get_per_replica_batch_size(FLAGS.batch_size)
        unrolls = [unroll_queue.dequeue() for _ in range(local_batch)]
        # Stack corresponding leaves of every unroll to form the batch.
        batch = tf.nest.map_structure(lambda *leaves: tf.stack(leaves),
                                      *unrolls)
        tm_prev_actions = utils.make_time_major(batch.prev_actions)
        tm_env_outputs = utils.make_time_major(batch.env_outputs)
        tm_agent_outputs = utils.make_time_major(batch.agent_outputs)
        batch = batch._replace(prev_actions=tm_prev_actions,
                               env_outputs=encode(tm_env_outputs),
                               agent_outputs=tm_agent_outputs)
        # tf.data.Dataset treats list leaves as tensors, so the structure is
        # flattened here and has to be repacked afterwards.
        return tf.nest.flatten(batch)

    def dataset_fn(ctx):
        """Build an endless dataset whose elements are batches from `dequeue`.

        A single dummy element is repeated forever (`repeat(None)`). Each
        element is mapped to `dequeue(ctx)`, with up to
        `ctx.num_replicas_in_sync` map calls running in parallel.
        """
        return (tf.data.Dataset.from_tensors(0)
                .repeat(None)
                .map(lambda _: dequeue(ctx),
                     num_parallel_calls=ctx.num_replicas_in_sync))

    dataset = training_strategy.experimental_distribute_datasets_from_function(
        dataset_fn)
    it = iter(dataset)

    def additional_logs():
        """Write the learning-rate summary and aggregated episode statistics.

        This is the extra logging callback given to `logger.start`. It takes
        only whole groups of `FLAGS.log_episode_frequency` episodes from
        `info_queue`; any remainder stays queued for a later call. For each
        group it writes one summary per statistic (mean over the group). It
        also logs every dequeued episode individually.
        """
        tf.summary.scalar('learning_rate', learning_rate_fn(iterations))
        ready = info_queue.size()
        # Round down to a multiple of the logging frequency so that every
        # summary averages exactly `log_episode_frequency` episodes.
        ready -= ready % FLAGS.log_episode_frequency
        if tf.not_equal(ready, 0):
            stats = info_queue.dequeue_many(ready)
            stat_names = [
                'episode_num_frames', 'episode_return', 'episode_raw_return'
            ]
            for stat_name, column in zip(stat_names, stats):
                num_groups = column.shape[0] // FLAGS.log_episode_frequency
                for group in tf.split(column, num_groups):
                    tf.summary.scalar(stat_name, tf.reduce_mean(group))

            for frames, ep_return, raw_return in zip(*stats):
                logging.info('Return: %f Raw return: %f Frames: %i', ep_return,
                             raw_return, frames)

    logger.start(additional_logs)
    # Execute learning.
    while iterations < final_iteration:
        # Save checkpoint.
        current_time = time.time()

        if current_time - last_ckpt_time >= FLAGS.save_checkpoint_secs:
            manager.save()
            # Apart from checkpointing, we also save the full model (including
            # the graph). This way we can load it after the code/parameters changed.
            tf.saved_model.save(agent, os.path.join(FLAGS.logdir,
                                                    'saved_model'))
            last_ckpt_time = current_time
        minimize(it)
    logger.shutdown()
    manager.save()
    tf.saved_model.save(agent, os.path.join(FLAGS.logdir, 'saved_model'))
    server.shutdown()
    unroll_queue.close()
    def test_ppo_training_step(self, batch_mode, use_agent_state):
        """Smoke-test `ppo_training_step` on a synthetic continuous-control unroll.

        Builds a fake unroll (sampled actions, random rewards, zero
        observations) for a 128-dim Box action space. It wraps a PPO
        `GeneralizedOnPolicyLoss` around a `ContinuousControlAgent` and runs
        one training step. Only successful execution is checked; the returned
        loss and logs are discarded.

        Args:
          batch_mode: Forwarded unchanged to `ppo_training_step`.
          use_agent_state: If true, feed an explicit zero core state of shape
            [batch, 64]; otherwise use the agent's own `initial_state`.
        """
        action_space = gym.spaces.Box(
            low=-1, high=1, shape=[128], dtype=np.float32)
        distribution = (
            parametric_distribution
            .get_parametric_distribution_for_action_space(action_space))
        agent = continuous_control_agent.ContinuousControlAgent(distribution)

        virtual_batch_size = 32
        time_steps = 5
        batches_per_step = 4
        step_shape = [time_steps, virtual_batch_size]

        done = tf.zeros(step_shape, dtype=tf.bool)
        sampled = [
            action_space.sample()
            for _ in range(time_steps * virtual_batch_size)
        ]
        prev_actions = tf.reshape(tf.stack(sampled), step_shape + [-1])
        env_outputs = utils.EnvOutput(
            reward=tf.random.uniform(step_shape),
            done=done,
            observation=tf.zeros(step_shape + [128], dtype=tf.float32),
            abandoned=tf.zeros_like(done),
            episode_step=tf.ones(step_shape, dtype=tf.int32))
        core_state = (tf.zeros([virtual_batch_size, 64]) if use_agent_state
                      else agent.initial_state(virtual_batch_size))
        agent_outputs, _ = agent((prev_actions, env_outputs), core_state,
                                 unroll=True)
        unroll = Unroll(core_state, prev_actions, env_outputs, agent_outputs)

        class DummyStrategy:
            def __init__(self):
                self.num_replicas_in_sync = 1

        loss_fn = generalized_onpolicy_loss.GeneralizedOnPolicyLoss(
            agent,
            popart.PopArt(running_statistics.FixedMeanStd(), compensate=False),
            distribution,
            ga_advantages.GAE(lambda_=0.9),
            policy_losses.ppo(0.9),
            discount_factor=0.99,
            regularizer=policy_regularizers.KLPolicyRegularizer(entropy=0.5),
            baseline_cost=0.5,
            max_abs_reward=None,
            frame_skip=1,
            reward_scaling=10)
        loss_fn.init()

        # Created in the same order as the original inline call arguments.
        strategy = DummyStrategy()
        optimizer = tf.keras.optimizers.Adam(1e-3)
        progress_logger = utils.ProgressLogger()
        # The unroll carries time_steps observations, i.e. time_steps - 1
        # transitions, hence the `- 1`.
        _, _ = ppo_training_step_utils.ppo_training_step(
            epochs_per_step=8,
            loss_fn=loss_fn,
            args=unroll,
            batch_mode=batch_mode,
            training_strategy=strategy,
            virtual_batch_size=virtual_batch_size,
            unroll_length=time_steps - 1,
            batches_per_step=batches_per_step,
            clip_norm=50.,
            optimizer=optimizer,
            logger=progress_logger)