Example #1
0
def train_eval(
    root_dir,
    tf_master='',
    env_name='HalfCheetah-v2',
    env_load_fn=suite_mujoco.load,
    random_seed=0,
    # TODO(b/127576522): rename to policy_fc_layers.
    actor_fc_layers=(200, 100),
    value_fc_layers=(200, 100),
    use_rnns=False,
    # Params for collect
    num_environment_steps=10000000,
    collect_episodes_per_iteration=30,
    num_parallel_environments=30,
    replay_buffer_capacity=1001,  # Per-environment
    # Params for train
    num_epochs=25,
    learning_rate=1e-4,
    # Params for eval
    num_eval_episodes=30,
    eval_interval=500,
    # Params for summaries and logging
    train_checkpoint_interval=100,
    policy_checkpoint_interval=50,
    rb_checkpoint_interval=200,
    log_interval=50,
    summary_interval=50,
    summaries_flush_secs=1,
    debug_summaries=False,
    summarize_grads_and_vars=False,
    eval_metrics_callback=None):
  """A simple train and eval for PPO (graph mode, `tf.compat.v1.Session`).

  Builds a PPO agent over `num_parallel_environments` parallel copies of
  `env_name`, then repeatedly collects `collect_episodes_per_iteration`
  episodes, trains on everything in the replay buffer and clears it, until the
  `EnvironmentSteps` metric reaches `num_environment_steps`. Evaluation,
  logging and checkpointing run on their own global-step intervals, and one
  final evaluation runs before returning.

  Args:
    root_dir: Base directory; summaries/checkpoints go under `root_dir/train`
      and `root_dir/eval`. `~` is expanded.
    tf_master: Target passed to `tf.compat.v1.Session`.
    env_name: Environment name passed to `env_load_fn`.
    env_load_fn: Callable that loads a py environment from `env_name`.
    random_seed: Graph-level seed for `tf.compat.v1.set_random_seed`.
    actor_fc_layers: Fully connected layer sizes of the actor network.
    value_fc_layers: Fully connected layer sizes of the value network.
    use_rnns: If True, build RNN actor and value networks.
    num_environment_steps: Training stops once this many environment steps
      have been counted.
    collect_episodes_per_iteration: Episodes collected per training iteration.
    num_parallel_environments: Number of parallel environments, used for both
      collection and evaluation.
    replay_buffer_capacity: Replay buffer length per environment.
    num_epochs: Passed to `PPOAgent`.
    learning_rate: Adam learning rate.
    num_eval_episodes: Episodes per evaluation.
    eval_interval: Evaluate when the global step is a multiple of this.
    train_checkpoint_interval: Agent/metrics checkpoint interval, in steps.
    policy_checkpoint_interval: Policy checkpoint interval, in steps.
    rb_checkpoint_interval: Replay-buffer checkpoint interval, in steps.
    log_interval: Logging interval, in steps.
    summary_interval: Train summaries are recorded only when the global step
      is a multiple of this.
    summaries_flush_secs: Summary writer flush period, in seconds.
    debug_summaries: Passed to `PPOAgent`.
    summarize_grads_and_vars: Passed to `PPOAgent`.
    eval_metrics_callback: Passed as `callback` to
      `metric_utils.compute_summaries`.

  Raises:
    AttributeError: If `root_dir` is None.
  """
  if root_dir is None:
    raise AttributeError('train_eval requires a root_dir.')

  root_dir = os.path.expanduser(root_dir)
  train_dir = os.path.join(root_dir, 'train')
  eval_dir = os.path.join(root_dir, 'eval')

  train_summary_writer = tf.compat.v2.summary.create_file_writer(
      train_dir, flush_millis=summaries_flush_secs * 1000)
  train_summary_writer.set_as_default()

  eval_summary_writer = tf.compat.v2.summary.create_file_writer(
      eval_dir, flush_millis=summaries_flush_secs * 1000)
  # Evaluation runs on a parallel (batched) py environment, so the py metrics
  # are batched to the same size.
  eval_metrics = [
      batched_py_metric.BatchedPyMetric(
          AverageReturnMetric,
          metric_args={'buffer_size': num_eval_episodes},
          batch_size=num_parallel_environments),
      batched_py_metric.BatchedPyMetric(
          AverageEpisodeLengthMetric,
          metric_args={'buffer_size': num_eval_episodes},
          batch_size=num_parallel_environments),
  ]
  eval_summary_writer_flush_op = eval_summary_writer.flush()

  global_step = tf.compat.v1.train.get_or_create_global_step()
  # Train summaries built in this scope record only every `summary_interval`
  # global steps.
  with tf.compat.v2.summary.record_if(
      lambda: tf.math.equal(global_step % summary_interval, 0)):
    tf.compat.v1.set_random_seed(random_seed)
    eval_py_env = parallel_py_environment.ParallelPyEnvironment(
        [lambda: env_load_fn(env_name)] * num_parallel_environments)
    tf_env = tf_py_environment.TFPyEnvironment(
        parallel_py_environment.ParallelPyEnvironment(
            [lambda: env_load_fn(env_name)] * num_parallel_environments))
    optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=learning_rate)

    if use_rnns:
      actor_net = actor_distribution_rnn_network.ActorDistributionRnnNetwork(
          tf_env.observation_spec(),
          tf_env.action_spec(),
          input_fc_layer_params=actor_fc_layers,
          output_fc_layer_params=None)
      value_net = value_rnn_network.ValueRnnNetwork(
          tf_env.observation_spec(),
          input_fc_layer_params=value_fc_layers,
          output_fc_layer_params=None)
    else:
      actor_net = actor_distribution_network.ActorDistributionNetwork(
          tf_env.observation_spec(),
          tf_env.action_spec(),
          fc_layer_params=actor_fc_layers)
      value_net = value_network.ValueNetwork(
          tf_env.observation_spec(), fc_layer_params=value_fc_layers)

    tf_agent = ppo_agent.PPOAgent(
        tf_env.time_step_spec(),
        tf_env.action_spec(),
        optimizer,
        actor_net=actor_net,
        value_net=value_net,
        num_epochs=num_epochs,
        debug_summaries=debug_summaries,
        summarize_grads_and_vars=summarize_grads_and_vars,
        train_step_counter=global_step)

    replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
        tf_agent.collect_data_spec,
        batch_size=num_parallel_environments,
        max_length=replay_buffer_capacity)

    eval_py_policy = py_tf_policy.PyTFPolicy(tf_agent.policy)

    environment_steps_metric = tf_metrics.EnvironmentSteps()
    environment_steps_count = environment_steps_metric.result()
    step_metrics = [
        tf_metrics.NumberOfEpisodes(),
        environment_steps_metric,
    ]
    train_metrics = step_metrics + [
        tf_metrics.AverageReturnMetric(),
        tf_metrics.AverageEpisodeLengthMetric(),
    ]

    # Add to replay buffer and other agent specific observers.
    replay_buffer_observer = [replay_buffer.add_batch]

    collect_policy = tf_agent.collect_policy

    collect_op = dynamic_episode_driver.DynamicEpisodeDriver(
        tf_env,
        collect_policy,
        observers=replay_buffer_observer + train_metrics,
        num_episodes=collect_episodes_per_iteration).run()

    # Train on everything gathered this iteration; the buffer is cleared only
    # after the train op, and `train_op` is re-wrapped so running it also runs
    # the clear.
    trajectories = replay_buffer.gather_all()

    train_op, _ = tf_agent.train(experience=trajectories)

    with tf.control_dependencies([train_op]):
      clear_replay_op = replay_buffer.clear()

    with tf.control_dependencies([clear_replay_op]):
      train_op = tf.identity(train_op)

    train_checkpointer = common.Checkpointer(
        ckpt_dir=train_dir,
        agent=tf_agent,
        global_step=global_step,
        metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))
    policy_checkpointer = common.Checkpointer(
        ckpt_dir=os.path.join(train_dir, 'policy'),
        policy=tf_agent.policy,
        global_step=global_step)
    rb_checkpointer = common.Checkpointer(
        ckpt_dir=os.path.join(train_dir, 'replay_buffer'),
        max_to_keep=1,
        replay_buffer=replay_buffer)

    summary_ops = []
    for train_metric in train_metrics:
      summary_ops.append(train_metric.tf_summaries(
          train_step=global_step, step_metrics=step_metrics))

    # Eval summaries always record (when run) and go to the eval writer.
    with eval_summary_writer.as_default(), \
         tf.compat.v2.summary.record_if(True):
      for eval_metric in eval_metrics:
        eval_metric.tf_summaries(
            train_step=global_step, step_metrics=step_metrics)

    init_agent_op = tf_agent.initialize()

    with tf.compat.v1.Session(tf_master) as sess:
      # Initialize graph: restore agent/metrics and replay buffer from
      # checkpoints when present, then initialize anything left over.
      train_checkpointer.initialize_or_restore(sess)
      rb_checkpointer.initialize_or_restore(sess)
      common.initialize_uninitialized_variables(sess)

      sess.run(init_agent_op)
      sess.run(train_summary_writer.init())
      sess.run(eval_summary_writer.init())

      def _evaluate(step):
        """Runs an evaluation labelled with `step`, then flushes eval summaries."""
        metric_utils.compute_summaries(
            eval_metrics,
            eval_py_env,
            eval_py_policy,
            num_episodes=num_eval_episodes,
            global_step=step,
            callback=eval_metrics_callback,
            log=True,
        )
        sess.run(eval_summary_writer_flush_op)

      collect_time = 0
      train_time = 0
      timed_at_step = sess.run(global_step)
      steps_per_second_ph = tf.compat.v1.placeholder(
          tf.float32, shape=(), name='steps_per_sec_ph')
      # This summary is only run on log steps. Force recording so the outer
      # `summary_interval` condition does not silently drop it when
      # `log_interval` is not a multiple of `summary_interval`.
      with tf.compat.v2.summary.record_if(True):
        steps_per_second_summary = tf.compat.v2.summary.scalar(
            name='global_steps_per_sec', data=steps_per_second_ph,
            step=global_step)

      while sess.run(environment_steps_count) < num_environment_steps:
        global_step_val = sess.run(global_step)
        if global_step_val % eval_interval == 0:
          _evaluate(global_step_val)

        start_time = time.time()
        sess.run(collect_op)
        collect_time += time.time() - start_time
        start_time = time.time()
        total_loss, _ = sess.run([train_op, summary_ops])
        train_time += time.time() - start_time

        # Training advanced the step; re-read it for logging/checkpoints.
        global_step_val = sess.run(global_step)
        if global_step_val % log_interval == 0:
          logging.info('step = %d, loss = %f', global_step_val, total_loss)
          steps_per_sec = (
              (global_step_val - timed_at_step) / (collect_time + train_time))
          logging.info('%.3f steps/sec', steps_per_sec)
          sess.run(
              steps_per_second_summary,
              feed_dict={steps_per_second_ph: steps_per_sec})
          logging.info('collect_time = %s, train_time = %s',
                       collect_time, train_time)
          timed_at_step = global_step_val
          collect_time = 0
          train_time = 0

        if global_step_val % train_checkpoint_interval == 0:
          train_checkpointer.save(global_step=global_step_val)

        if global_step_val % policy_checkpoint_interval == 0:
          policy_checkpointer.save(global_step=global_step_val)

        if global_step_val % rb_checkpoint_interval == 0:
          rb_checkpointer.save(global_step=global_step_val)

      # One final eval before exiting. Read the step fresh: if the restored
      # metrics already reached `num_environment_steps`, the loop body never
      # ran and the loop's `global_step_val` would be unbound.
      _evaluate(sess.run(global_step))
Example #2
0
def train_eval(
    root_dir,
    env_name='HalfCheetah-v2',
    env_load_fn=suite_mujoco.load,
    random_seed=0,
    # TODO(b/127576522): rename to policy_fc_layers.
    actor_fc_layers=(200, 100),
    value_fc_layers=(200, 100),
    use_rnns=False,
    # Params for collect
    num_environment_steps=10000000,
    collect_episodes_per_iteration=30,
    num_parallel_environments=30,
    replay_buffer_capacity=1001,  # Per-environment
    # Params for train
    num_epochs=25,
    learning_rate=1e-4,
    # Params for eval
    num_eval_episodes=30,
    eval_interval=500,
    # Params for summaries and logging
    log_interval=50,
    summary_interval=50,
    summaries_flush_secs=1,
    use_tf_functions=True,
    debug_summaries=False,
    summarize_grads_and_vars=False):
  """A simple train and eval for PPO (eager mode).

  Builds a PPO agent over `num_parallel_environments` parallel copies of
  `env_name`, then repeatedly collects `collect_episodes_per_iteration`
  episodes, trains on everything in the replay buffer and clears it, until the
  `EnvironmentSteps` metric reaches `num_environment_steps`. Evaluation runs
  on a single separate environment every `eval_interval` global steps and once
  more before returning.

  Args:
    root_dir: Base directory; summaries go under `root_dir/train` and
      `root_dir/eval`. `~` is expanded.
    env_name: Environment name passed to `env_load_fn`.
    env_load_fn: Callable that loads a py environment from `env_name`.
    random_seed: Seed for `tf.compat.v1.set_random_seed`.
    actor_fc_layers: Fully connected layer sizes of the actor network.
    value_fc_layers: Fully connected layer sizes of the value network.
    use_rnns: If True, build RNN actor and value networks.
    num_environment_steps: Training stops once this many environment steps
      have been counted.
    collect_episodes_per_iteration: Episodes collected per training iteration.
    num_parallel_environments: Number of parallel collection environments.
    replay_buffer_capacity: Replay buffer length per environment.
    num_epochs: Passed to `PPOAgent`.
    learning_rate: Adam learning rate.
    num_eval_episodes: Episodes per evaluation.
    eval_interval: Evaluate when the global step is a multiple of this.
    log_interval: Logging interval, in global steps.
    summary_interval: Train summaries are recorded only when the global step
      is a multiple of this.
    summaries_flush_secs: Summary writer flush period, in seconds.
    use_tf_functions: If True, wrap collection and training in
      `common.function`.
    debug_summaries: Passed to `PPOAgent`.
    summarize_grads_and_vars: Passed to `PPOAgent`.

  Raises:
    AttributeError: If `root_dir` is None.
  """
  if root_dir is None:
    raise AttributeError('train_eval requires a root_dir.')

  root_dir = os.path.expanduser(root_dir)
  train_dir = os.path.join(root_dir, 'train')
  eval_dir = os.path.join(root_dir, 'eval')

  train_summary_writer = tf.compat.v2.summary.create_file_writer(
      train_dir, flush_millis=summaries_flush_secs * 1000)
  train_summary_writer.set_as_default()

  eval_summary_writer = tf.compat.v2.summary.create_file_writer(
      eval_dir, flush_millis=summaries_flush_secs * 1000)
  eval_metrics = [
      tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
      tf_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes)
  ]

  global_step = tf.compat.v1.train.get_or_create_global_step()
  # Summaries in this scope record only every `summary_interval` steps.
  with tf.compat.v2.summary.record_if(
      lambda: tf.math.equal(global_step % summary_interval, 0)):
    tf.compat.v1.set_random_seed(random_seed)
    eval_tf_env = tf_py_environment.TFPyEnvironment(env_load_fn(env_name))
    tf_env = tf_py_environment.TFPyEnvironment(
        parallel_py_environment.ParallelPyEnvironment(
            [lambda: env_load_fn(env_name)] * num_parallel_environments))
    optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=learning_rate)

    if use_rnns:
      actor_net = actor_distribution_rnn_network.ActorDistributionRnnNetwork(
          tf_env.observation_spec(),
          tf_env.action_spec(),
          input_fc_layer_params=actor_fc_layers,
          output_fc_layer_params=None)
      value_net = value_rnn_network.ValueRnnNetwork(
          tf_env.observation_spec(),
          input_fc_layer_params=value_fc_layers,
          output_fc_layer_params=None)
    else:
      actor_net = actor_distribution_network.ActorDistributionNetwork(
          tf_env.observation_spec(),
          tf_env.action_spec(),
          fc_layer_params=actor_fc_layers)
      value_net = value_network.ValueNetwork(
          tf_env.observation_spec(), fc_layer_params=value_fc_layers)

    tf_agent = ppo_agent.PPOAgent(
        tf_env.time_step_spec(),
        tf_env.action_spec(),
        optimizer,
        actor_net=actor_net,
        value_net=value_net,
        num_epochs=num_epochs,
        debug_summaries=debug_summaries,
        summarize_grads_and_vars=summarize_grads_and_vars,
        train_step_counter=global_step)
    tf_agent.initialize()

    environment_steps_metric = tf_metrics.EnvironmentSteps()
    step_metrics = [
        tf_metrics.NumberOfEpisodes(),
        environment_steps_metric,
    ]

    train_metrics = step_metrics + [
        tf_metrics.AverageReturnMetric(),
        tf_metrics.AverageEpisodeLengthMetric(),
    ]

    eval_policy = tf_agent.policy
    collect_policy = tf_agent.collect_policy

    replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
        tf_agent.collect_data_spec,
        batch_size=num_parallel_environments,
        max_length=replay_buffer_capacity)

    collect_driver = dynamic_episode_driver.DynamicEpisodeDriver(
        tf_env,
        collect_policy,
        observers=[replay_buffer.add_batch] + train_metrics,
        num_episodes=collect_episodes_per_iteration)

    if use_tf_functions:
      # TODO(b/123828980): Enable once the cause for slowdown was identified.
      collect_driver.run = common.function(collect_driver.run, autograph=False)
      tf_agent.train = common.function(tf_agent.train, autograph=False)

    def _evaluate():
      """Runs `num_eval_episodes` eval episodes and writes eval summaries."""
      metric_utils.eager_compute(
          eval_metrics,
          eval_tf_env,
          eval_policy,
          num_episodes=num_eval_episodes,
          train_step=global_step,
          summary_writer=eval_summary_writer,
          summary_prefix='Metrics',
      )

    collect_time = 0
    train_time = 0
    timed_at_step = global_step.numpy()

    while environment_steps_metric.result() < num_environment_steps:
      if global_step.numpy() % eval_interval == 0:
        _evaluate()

      start_time = time.time()
      collect_driver.run()
      collect_time += time.time() - start_time

      # Train on everything collected, then empty the buffer for the next
      # iteration.
      start_time = time.time()
      trajectories = replay_buffer.gather_all()
      total_loss, _ = tf_agent.train(experience=trajectories)
      replay_buffer.clear()
      train_time += time.time() - start_time

      for train_metric in train_metrics:
        train_metric.tf_summaries(
            train_step=global_step, step_metrics=step_metrics)

      # Read the step after training so the logged step, loss and steps/sec
      # all describe the same post-train state.
      global_step_val = global_step.numpy()
      if global_step_val % log_interval == 0:
        logging.info('step = %d, loss = %f', global_step_val, total_loss)
        steps_per_sec = (
            (global_step_val - timed_at_step) / (collect_time + train_time))
        logging.info('%.3f steps/sec', steps_per_sec)
        logging.info('collect_time = %s, train_time = %s',
                     collect_time, train_time)
        # Always record: log steps need not align with `summary_interval`.
        with tf.compat.v2.summary.record_if(True):
          tf.compat.v2.summary.scalar(
              name='global_steps_per_sec', data=steps_per_sec, step=global_step)

        timed_at_step = global_step_val
        collect_time = 0
        train_time = 0

    # One final eval before exiting.
    _evaluate()
Example #3
0
def create_ppo_agent(env, global_step, FLAGS):
    """Builds and initializes a PPO agent with per-key observation encoders.

    Observations are preprocessed by one Keras model per key ('minimap',
    'screen', 'info', 'entities'; presumably matching the structure of
    `env.observation_spec()` -- confirm). The encodings are concatenated on
    the last axis and then fed to LSTM networks when `FLAGS.use_lstms` is set,
    or to feed-forward networks with tanh activations otherwise.

    Args:
        env: Environment providing `observation_spec()`, `action_spec()` and
            `time_step_spec()`.
        global_step: Counter passed to the agent as `train_step_counter`.
        FLAGS: Config object; reads `use_lstms`, `learning_rate`,
            `num_epochs` and `norm_rewards`.

    Returns:
        The `my_ppo_agent.PPOAgent`, after `initialize()` has been called.
    """
    feedforward_layers = (512, 256)
    lstm_options = dict(input_fc_layer_params=(1024, 512),
                        output_fc_layer_params=(256, 256),
                        lstm_size=(256, ))

    def conv_encoder():
        # Two stride-2 convolutions, flattened and projected to 256 units.
        return tf.keras.models.Sequential([
            tf.keras.layers.Conv2D(filters=16, kernel_size=(5, 5),
                                   strides=(2, 2), activation='relu'),
            tf.keras.layers.Conv2D(filters=32, kernel_size=(3, 3),
                                   strides=(2, 2), activation='relu'),
            tf.keras.layers.Flatten(),
            tf.keras.layers.Dense(units=256, activation='relu'),
        ])

    # Entries are built in this order: minimap, screen, info, entities.
    # NOTE(review): the same encoder objects go to both the actor and the
    # value network, so the two presumably share encoder weights -- confirm
    # this is intended.
    encoders = {
        'minimap': conv_encoder(),
        'screen': conv_encoder(),
        'info': tf.keras.models.Sequential([
            tf.keras.layers.Dense(units=128, activation='relu'),
            tf.keras.layers.Dense(units=128, activation='relu'),
        ]),
        'entities': tf.keras.models.Sequential([
            tf.keras.layers.Conv1D(filters=4, kernel_size=4,
                                   activation='relu'),
            tf.keras.layers.Flatten(),
            tf.keras.layers.Dense(units=256, activation='relu'),
        ]),
    }

    obs_spec = env.observation_spec()
    act_spec = env.action_spec()

    actor_combiner = tf.keras.layers.Concatenate(axis=-1)
    if not FLAGS.use_lstms:
        actor_net = actor_distribution_network.ActorDistributionNetwork(
            input_tensor_spec=obs_spec,
            output_tensor_spec=act_spec,
            preprocessing_layers=encoders,
            preprocessing_combiner=actor_combiner,
            fc_layer_params=feedforward_layers,
            activation_fn=tf.keras.activations.tanh)
    else:
        actor_net = actor_distribution_rnn_network.ActorDistributionRnnNetwork(
            obs_spec,
            act_spec,
            preprocessing_layers=encoders,
            preprocessing_combiner=actor_combiner,
            **lstm_options)

    # A distinct dict holding the same encoder objects as the actor's.
    value_encoders = dict(encoders)
    value_combiner = tf.keras.layers.Concatenate(axis=-1)
    if not FLAGS.use_lstms:
        value_net = value_network.ValueNetwork(
            obs_spec,
            preprocessing_layers=value_encoders,
            preprocessing_combiner=value_combiner,
            fc_layer_params=feedforward_layers,
            activation_fn=tf.keras.activations.tanh)
    else:
        value_net = value_rnn_network.ValueRnnNetwork(
            obs_spec,
            preprocessing_layers=value_encoders,
            preprocessing_combiner=value_combiner,
            **lstm_options)

    optimizer = tf.compat.v1.train.AdamOptimizer(
        learning_rate=FLAGS.learning_rate)

    # Options left unset, listed in the original notes as the agent's
    # defaults: lambda_value=0.95, policy_l2_reg=0.0,
    # value_function_l2_reg=0.0, shared_vars_l2_reg=0.0,
    # value_pred_loss_coef=0.5, log_prob_clipping=0.0, gradient_clipping=None,
    # check_numerics=False, compute_value_and_advantage_in_train=True,
    # update_normalizers_in_train=True, debug_summaries=False,
    # summarize_grads_and_vars=False, name='PPOClipAgent'.
    agent = my_ppo_agent.PPOAgent(
        time_step_spec=env.time_step_spec(),
        action_spec=act_spec,
        optimizer=optimizer,
        actor_net=actor_net,
        value_net=value_net,
        importance_ratio_clipping=0.1,
        discount_factor=0.95,
        entropy_regularization=0.003,
        num_epochs=FLAGS.num_epochs,
        use_gae=True,
        use_td_lambda_return=True,
        normalize_rewards=FLAGS.norm_rewards,
        reward_norm_clipping=0.0,
        normalize_observations=True,
        # KL penalties: kl_cutoff_factor=0.0 disables the fixed KL cutoff
        # penalty and initial_adaptive_kl_beta=0.0 disables the adaptive one.
        kl_cutoff_factor=0.0,
        kl_cutoff_coef=0.0,
        initial_adaptive_kl_beta=0.0,
        adaptive_kl_target=0.00,
        adaptive_kl_tolerance=0.0,
        value_clipping=0.5,
        train_step_counter=global_step,
    )

    agent.initialize()
    return agent
Example #4
0
def train_eval(
        root_dir,
        env_load_fn=None,
        random_seed=0,
        # TODO(b/127576522): rename to policy_fc_layers.
        actor_fc_layers=(200, 100),
        value_fc_layers=(200, 100),
        use_rnns=False,
        # Params for collect
        num_environment_steps=int(1e7),
        collect_episodes_per_iteration=30,
        num_parallel_environments=30,
        replay_buffer_capacity=1001,  # Per-environment
        # Params for train
        num_epochs=25,
        learning_rate=1e-4,
        # Params for eval
        num_eval_episodes=30,
        eval_interval=500,
        # Params for summaries and logging
        train_checkpoint_interval=500,
        policy_checkpoint_interval=500,
        log_interval=50,
        summary_interval=50,
        summaries_flush_secs=1,
        use_tf_functions=True,
        debug_summaries=False,
        summarize_grads_and_vars=False):
    """Trains and evaluates a PPO agent on `RectEnv` (eager mode).

    Collection uses `num_parallel_environments` parallel `RectEnv`s wrapped by
    `suite_gym.wrap_env`; evaluation uses one more. Observations are encoded
    per key ('target', 'canvas', 'coord') and concatenated before the actor
    and value networks. Each iteration collects
    `collect_episodes_per_iteration` episodes, trains on the whole replay
    buffer and clears it, until `num_environment_steps` environment steps have
    been counted.

    Args:
        root_dir: Base directory for `train/`, `eval/` and
            `policy_saved_model/` outputs. `~` is expanded.
        env_load_fn: Unused by this function (environments are always built
            from `RectEnv`); kept for signature compatibility.
        random_seed: Seed for `tf.compat.v1.set_random_seed`.
        actor_fc_layers: Fully connected layer sizes of the actor network.
        value_fc_layers: Fully connected layer sizes of the value network.
        use_rnns: If True, build RNN actor and value networks.
        num_environment_steps: Training stops once this many environment
            steps have been counted.
        collect_episodes_per_iteration: Episodes collected per iteration.
        num_parallel_environments: Number of parallel collection environments.
        replay_buffer_capacity: Replay buffer length per environment.
        num_epochs: Passed to `PPOAgent`.
        learning_rate: Adam learning rate.
        num_eval_episodes: Episodes per evaluation.
        eval_interval: Evaluate when the global step is a multiple of this.
        train_checkpoint_interval: Agent/metrics checkpoint interval (steps).
        policy_checkpoint_interval: Policy checkpoint and SavedModel export
            interval (steps).
        log_interval: Logging interval (steps).
        summary_interval: Train summaries are recorded only when the global
            step is a multiple of this.
        summaries_flush_secs: Summary writer flush period, in seconds.
        use_tf_functions: If True, wrap collection and training in
            `common.function`.
        debug_summaries: Passed to `PPOAgent`.
        summarize_grads_and_vars: Passed to `PPOAgent`.

    Raises:
        AttributeError: If `root_dir` is None.
    """
    if root_dir is None:
        raise AttributeError('train_eval requires a root_dir.')

    root_dir = os.path.expanduser(root_dir)
    train_dir = os.path.join(root_dir, 'train')
    eval_dir = os.path.join(root_dir, 'eval')
    saved_model_dir = os.path.join(root_dir, 'policy_saved_model')

    train_summary_writer = tf.compat.v2.summary.create_file_writer(
        train_dir, flush_millis=summaries_flush_secs * 1000)
    train_summary_writer.set_as_default()

    eval_summary_writer = tf.compat.v2.summary.create_file_writer(
        eval_dir, flush_millis=summaries_flush_secs * 1000)
    eval_metrics = [
        tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
        tf_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes)
    ]

    global_step = tf.compat.v1.train.get_or_create_global_step()
    # Summaries in this scope record only every `summary_interval` steps.
    with tf.compat.v2.summary.record_if(
            lambda: tf.math.equal(global_step % summary_interval, 0)):
        tf.compat.v1.set_random_seed(random_seed)

        eval_tf_env = tf_py_environment.TFPyEnvironment(
            suite_gym.wrap_env(RectEnv()))
        tf_env = tf_py_environment.TFPyEnvironment(
            parallel_py_environment.ParallelPyEnvironment(
                [lambda: suite_gym.wrap_env(RectEnv())] *
                num_parallel_environments))

        optimizer = tf.compat.v1.train.AdamOptimizer(
            learning_rate=learning_rate)

        def _image_encoder():
            # `easy.encoder` is given (CANVAS_WIDTH, CANVAS_WIDTH, 1) --
            # presumably the single-channel image input shape; its output is
            # flattened.
            return tf.keras.models.Sequential([
                easy.encoder((CANVAS_WIDTH, CANVAS_WIDTH, 1)),
                tf.keras.layers.Flatten()
            ])

        # One encoder per observation key; the encodings are concatenated.
        preprocessing_layers = {
            'target': _image_encoder(),
            'canvas': _image_encoder(),
            'coord': tf.keras.models.Sequential([
                tf.keras.layers.Dense(64),
                tf.keras.layers.Dense(64),
                tf.keras.layers.Flatten()
            ]),
        }
        preprocessing_combiner = tf.keras.layers.Concatenate(axis=-1)

        # Both network families need the preprocessing layers and combiner:
        # the observations are keyed like `preprocessing_layers` and must be
        # encoded and combined before the fc/LSTM stack.
        if use_rnns:
            actor_net = actor_distribution_rnn_network.ActorDistributionRnnNetwork(
                tf_env.observation_spec(),
                tf_env.action_spec(),
                preprocessing_layers=preprocessing_layers,
                preprocessing_combiner=preprocessing_combiner,
                input_fc_layer_params=actor_fc_layers,
                output_fc_layer_params=None)
            value_net = value_rnn_network.ValueRnnNetwork(
                tf_env.observation_spec(),
                preprocessing_layers=preprocessing_layers,
                preprocessing_combiner=preprocessing_combiner,
                input_fc_layer_params=value_fc_layers,
                output_fc_layer_params=None)
        else:
            actor_net = actor_distribution_network.ActorDistributionNetwork(
                tf_env.observation_spec(),
                tf_env.action_spec(),
                fc_layer_params=actor_fc_layers,
                preprocessing_layers=preprocessing_layers,
                preprocessing_combiner=preprocessing_combiner)
            value_net = value_network.ValueNetwork(
                tf_env.observation_spec(),
                fc_layer_params=value_fc_layers,
                preprocessing_layers=preprocessing_layers,
                preprocessing_combiner=preprocessing_combiner)

        tf_agent = ppo_agent.PPOAgent(
            tf_env.time_step_spec(),
            tf_env.action_spec(),
            optimizer,
            actor_net=actor_net,
            value_net=value_net,
            num_epochs=num_epochs,
            debug_summaries=debug_summaries,
            summarize_grads_and_vars=summarize_grads_and_vars,
            train_step_counter=global_step)
        tf_agent.initialize()

        environment_steps_metric = tf_metrics.EnvironmentSteps()
        step_metrics = [
            tf_metrics.NumberOfEpisodes(),
            environment_steps_metric,
        ]

        train_metrics = step_metrics + [
            tf_metrics.AverageReturnMetric(
                batch_size=num_parallel_environments),
            tf_metrics.AverageEpisodeLengthMetric(
                batch_size=num_parallel_environments),
        ]

        eval_policy = tf_agent.policy
        collect_policy = tf_agent.collect_policy

        replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            tf_agent.collect_data_spec,
            batch_size=num_parallel_environments,
            max_length=replay_buffer_capacity)

        train_checkpointer = common.Checkpointer(
            ckpt_dir=train_dir,
            max_to_keep=5,
            agent=tf_agent,
            global_step=global_step,
            metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))
        policy_checkpointer = common.Checkpointer(
            ckpt_dir=os.path.join(train_dir, 'policy'),
            max_to_keep=5,
            policy=eval_policy,
            global_step=global_step)
        saved_model = policy_saver.PolicySaver(eval_policy,
                                               train_step=global_step)

        # Resume agent, step counter and train metrics if a checkpoint exists.
        train_checkpointer.initialize_or_restore()

        collect_driver = dynamic_episode_driver.DynamicEpisodeDriver(
            tf_env,
            collect_policy,
            observers=[replay_buffer.add_batch] + train_metrics,
            num_episodes=collect_episodes_per_iteration)

        def train_step():
            trajectories = replay_buffer.gather_all()
            return tf_agent.train(experience=trajectories)

        if use_tf_functions:
            # TODO(b/123828980): Enable once the cause for slowdown was identified.
            collect_driver.run = common.function(collect_driver.run,
                                                 autograph=False)
            tf_agent.train = common.function(tf_agent.train, autograph=False)
            train_step = common.function(train_step)

        def _evaluate():
            """Runs `num_eval_episodes` eval episodes and writes summaries."""
            metric_utils.eager_compute(
                eval_metrics,
                eval_tf_env,
                eval_policy,
                num_episodes=num_eval_episodes,
                train_step=global_step,
                summary_writer=eval_summary_writer,
                summary_prefix='Metrics',
            )

        collect_time = 0
        train_time = 0
        timed_at_step = global_step.numpy()

        while environment_steps_metric.result() < num_environment_steps:
            if global_step.numpy() % eval_interval == 0:
                _evaluate()

            start_time = time.time()
            collect_driver.run()
            collect_time += time.time() - start_time

            # Train on everything collected, then empty the buffer.
            start_time = time.time()
            total_loss, _ = train_step()
            replay_buffer.clear()
            train_time += time.time() - start_time

            for train_metric in train_metrics:
                train_metric.tf_summaries(train_step=global_step,
                                          step_metrics=step_metrics)

            # Read the step after training so the logged step and the
            # checkpoint labels match the state actually logged/saved.
            global_step_val = global_step.numpy()
            if global_step_val % log_interval == 0:
                logging.info('step = %d, loss = %f', global_step_val,
                             total_loss)
                steps_per_sec = ((global_step_val - timed_at_step) /
                                 (collect_time + train_time))
                logging.info('%.3f steps/sec', steps_per_sec)
                logging.info('collect_time = %s, train_time = %s',
                             collect_time, train_time)
                # Always record: log steps need not align with
                # `summary_interval`.
                with tf.compat.v2.summary.record_if(True):
                    tf.compat.v2.summary.scalar(name='global_steps_per_sec',
                                                data=steps_per_sec,
                                                step=global_step)

                timed_at_step = global_step_val
                collect_time = 0
                train_time = 0

            # Checkpointing runs on its own intervals, independent of
            # `log_interval`.
            if global_step_val % train_checkpoint_interval == 0:
                train_checkpointer.save(global_step=global_step_val)

            if global_step_val % policy_checkpoint_interval == 0:
                policy_checkpointer.save(global_step=global_step_val)
                saved_model_path = os.path.join(
                    saved_model_dir,
                    'policy_' + ('%d' % global_step_val).zfill(9))
                saved_model.save(saved_model_path)

        # One final eval before exiting.
        _evaluate()
Example #5
0
def construct_multigrid_networks(observation_spec,
                                 action_spec,
                                 use_rnns=True,
                                 actor_fc_layers=(200, 100),
                                 value_fc_layers=(200, 100),
                                 lstm_size=(128, ),
                                 conv_filters=8,
                                 conv_kernel=3,
                                 scalar_fc=5,
                                 scalar_name='direction',
                                 scalar_dim=4,
                                 random_z=False,
                                 xy_dim=None):
    """Build an actor network and a critic network for MultiGrid observations.

    The 'image' observation is passed through `cast_and_scale()`, a single
    Conv2D layer and a Flatten layer. The scalar observation named by
    `scalar_name` is one-hot encoded and fed to a Dense layer. Optional extra
    inputs ('random_z', 'x', 'y') get their own preprocessing layers. All
    preprocessed inputs are concatenated on the last axis before the fully
    connected layers (and, when `use_rnns` is true, the LSTM).

    Args:
      observation_spec: A tf-agents observation spec.
      action_spec: A tf-agents action spec.
      use_rnns: If True, build recurrent (LSTM) networks.
      actor_fc_layers: Sizes of the fully connected layers in the actor.
      value_fc_layers: Sizes of the fully connected layers in the critic.
      lstm_size: Number of cells in each LSTM layer (recurrent case only).
      conv_filters: Number of convolution filters.
      conv_kernel: Size of the convolution kernel.
      scalar_fc: Width of the Dense layer applied to the scalar input.
      scalar_name: Observation key of the scalar input.
      scalar_dim: Number of one-hot categories used for the scalar input.
      random_z: If True, add an identity preprocessing layer for a
        'random_z' input.
      xy_dim: If not None, add one-hot preprocessing layers for 'x' and 'y'
        inputs, using `xy_dim` as the one-hot depth.

    Returns:
      A pair `(actor_net, value_net)`. The pair is an
      ActorDistributionRnnNetwork and a ValueRnnNetwork when `use_rnns` is
      true. Otherwise it is an ActorDistributionNetwork and a ValueNetwork
      that use tanh activations.
    """
    image_encoder = tf.keras.models.Sequential([
        cast_and_scale(),
        tf.keras.layers.Conv2D(conv_filters, conv_kernel),
        tf.keras.layers.Flatten()
    ])
    scalar_encoder = tf.keras.models.Sequential(
        [one_hot_layer(scalar_dim),
         tf.keras.layers.Dense(scalar_fc)])

    # Populate in the same key order as before: image, scalar, then extras.
    encoders = {'image': image_encoder}
    encoders[scalar_name] = scalar_encoder

    if random_z:
        # Identity layer: the random vector is forwarded unchanged.
        encoders['random_z'] = tf.keras.models.Sequential(
            [tf.keras.layers.Lambda(lambda x: x)])

    if xy_dim is not None:
        for coordinate in ('x', 'y'):
            encoders[coordinate] = tf.keras.models.Sequential(
                [one_hot_layer(xy_dim)])

    combiner = tf.keras.layers.Concatenate(axis=-1)
    shared_kwargs = dict(preprocessing_layers=encoders,
                         preprocessing_combiner=combiner)

    if not use_rnns:
        tanh = tf.keras.activations.tanh
        feedforward_actor = actor_distribution_network.ActorDistributionNetwork(
            observation_spec,
            action_spec,
            fc_layer_params=actor_fc_layers,
            activation_fn=tanh,
            **shared_kwargs)
        feedforward_critic = value_network.ValueNetwork(
            observation_spec,
            fc_layer_params=value_fc_layers,
            activation_fn=tanh,
            **shared_kwargs)
        return feedforward_actor, feedforward_critic

    recurrent_actor = actor_distribution_rnn_network.ActorDistributionRnnNetwork(
        observation_spec,
        action_spec,
        input_fc_layer_params=actor_fc_layers,
        output_fc_layer_params=None,
        lstm_size=lstm_size,
        **shared_kwargs)
    recurrent_critic = value_rnn_network.ValueRnnNetwork(
        observation_spec,
        input_fc_layer_params=value_fc_layers,
        output_fc_layer_params=None,
        lstm_size=lstm_size,
        **shared_kwargs)
    return recurrent_actor, recurrent_critic
def train_eval(
        root_dir,
        env_name=None,
        env_load_fn=suite_mujoco.load,
        random_seed=0,
        # TODO(b/127576522): rename to policy_fc_layers.
        actor_fc_layers=(200, 100),
        value_fc_layers=(200, 100),
        inference_fc_layers=(200, 100),
        use_rnns=None,
        dim_z=4,
        categorical=True,
        # Params for collect
        num_environment_steps=10000000,
        collect_episodes_per_iteration=30,
        num_parallel_environments=30,
        replay_buffer_capacity=1001,  # Per-environment
        # Params for train
        num_epochs=25,
        learning_rate=1e-4,
        entropy_regularization=None,
        kl_posteriors_penalty=None,
        mock_inference=None,
        mock_reward=None,
        l2_distance=None,
        rl_steps=None,
        inference_steps=None,
        # Params for eval
        num_eval_episodes=30,
        eval_interval=1000,
        # Params for summaries and logging
        train_checkpoint_interval=10000,
        policy_checkpoint_interval=10000,
        log_interval=1000,
        summary_interval=1000,
        summaries_flush_secs=1,
        use_tf_functions=True,
        debug_summaries=False,
        summarize_grads_and_vars=False):
    """Train and evaluate a PPO agent with a DIAYN inference network.

    Every environment is wrapped in `diayn_gym_env.DiaynGymEnv(env, dim_z,
    categorical)`. The resulting observation is a dict with an 'observation'
    entry and a 'z' entry. The actor and critic consume the whole dict. A
    separate inference network predicts `z` from the 'observation' entry, or,
    for env "AntRandGoalEval-v1", from a 2-element 'body_com' input.

    Each iteration does two things:
      1. It collects `collect_episodes_per_iteration` episodes into the replay
         buffer.
      2. It runs one `PPODiaynAgent.train` call on everything gathered, then
         clears the buffer.
    This repeats until `num_environment_steps` environment steps have been
    taken.

    Summaries, evaluation, logging and checkpointing are each scheduled
    independently on multiples of the global step:
      * `summary_interval` gates summaries.
      * `eval_interval` triggers evaluation.
      * `log_interval` triggers logging.
      * `train_checkpoint_interval` and `policy_checkpoint_interval` trigger
        checkpointing.

    Args:
      root_dir: Directory under which train/eval summaries, checkpoints and
        saved policies are written.
      env_name: Name passed to `env_load_fn`.
      env_load_fn: Callable `(env_name, gym_env_wrappers=...)` returning a py
        environment.
      random_seed: Seed passed to `tf.compat.v1.set_random_seed`.
      actor_fc_layers: Sizes of the actor's fully connected layers.
      value_fc_layers: Sizes of the critic's fully connected layers.
      inference_fc_layers: Sizes of the inference network's fully connected
        layers.
      use_rnns: If truthy, use recurrent actor/critic networks.
      dim_z: Dimension passed to `DiaynGymEnv`; also the one-hot depth when
        the 'z' spec is discrete.
      categorical: Passed through to `DiaynGymEnv`.
      num_environment_steps: Stop once this many environment steps are done.
      collect_episodes_per_iteration: Episodes collected per iteration.
      num_parallel_environments: Number of collection environments; also
        the replay buffer batch size.
      replay_buffer_capacity: Per-environment replay buffer length.
      num_epochs: PPO epochs per train call.
      learning_rate: Adam learning rate.
      entropy_regularization, kl_posteriors_penalty, mock_inference,
        mock_reward, l2_distance, rl_steps, inference_steps: Passed through
        unchanged to `PPODiaynAgent`.
      num_eval_episodes: Episodes per evaluation.
      eval_interval: Evaluate when the global step is a multiple of this.
      train_checkpoint_interval: Save agent, inference net, actor net and
        replay buffer when the global step is a multiple of this.
      policy_checkpoint_interval: Save the policy checkpoint and a
        SavedModel when the global step is a multiple of this.
      log_interval: Log loss and throughput when the global step is a
        multiple of this.
      summary_interval: Record TF summaries on multiples of this step.
      summaries_flush_secs: Summary writer flush period in seconds.
      use_tf_functions: Wrap collection and training in `common.function`.
      debug_summaries: Passed to the agent.
      summarize_grads_and_vars: Passed to the agent.

    Raises:
      AttributeError: If `root_dir` is None.
    """
    if root_dir is None:
        raise AttributeError('train_eval requires a root_dir.')

    root_dir = os.path.expanduser(root_dir)
    train_dir = os.path.join(root_dir, 'train')
    eval_dir = os.path.join(root_dir, 'eval')
    saved_model_dir = os.path.join(root_dir, 'policy_saved_model')

    train_summary_writer = tf.compat.v2.summary.create_file_writer(
        train_dir, flush_millis=summaries_flush_secs * 1000)
    train_summary_writer.set_as_default()

    eval_summary_writer = tf.compat.v2.summary.create_file_writer(
        eval_dir, flush_millis=summaries_flush_secs * 1000)
    eval_metrics = [
        tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
        tf_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes)
    ]

    global_step = tf.compat.v1.train.get_or_create_global_step()
    with tf.compat.v2.summary.record_if(
            lambda: tf.math.equal(global_step % summary_interval, 0)):
        tf.compat.v1.set_random_seed(random_seed)

        def _env_load_fn(env_name):
            # Wrap every environment so it exposes the DIAYN 'z' input.
            diayn_wrapper = (
                lambda x: diayn_gym_env.DiaynGymEnv(x, dim_z, categorical))
            return env_load_fn(
                env_name,
                gym_env_wrappers=[diayn_wrapper],
            )

        eval_tf_env = tf_py_environment.TFPyEnvironment(_env_load_fn(env_name))
        if num_parallel_environments == 1:
            py_env = _env_load_fn(env_name)
        else:
            py_env = parallel_py_environment.ParallelPyEnvironment(
                [lambda: _env_load_fn(env_name)] * num_parallel_environments)
        tf_env = tf_py_environment.TFPyEnvironment(py_env)

        augmented_time_step_spec = tf_env.time_step_spec()
        augmented_observation_spec = augmented_time_step_spec.observation
        observation_spec = augmented_observation_spec['observation']
        z_spec = augmented_observation_spec['z']
        action_spec = tf_env.action_spec()

        # For this env the inference net sees a 2-element 'body_com' input
        # instead of the raw observation.
        infer_from_com = env_name == "AntRandGoalEval-v1"
        if infer_from_com:
            input_inference_spec = tspec.BoundedTensorSpec(
                shape=[2],
                dtype=tf.float64,
                minimum=-1.79769313e+308,
                maximum=1.79769313e+308,
                name='body_com')
        else:
            input_inference_spec = observation_spec

        if tensor_spec.is_discrete(z_spec):
            _preprocessing_combiner = OneHotConcatenateLayer(dim_z)
        else:
            _preprocessing_combiner = DictConcatenateLayer()

        optimizer = tf.compat.v1.train.AdamOptimizer(
            learning_rate=learning_rate)

        if use_rnns:
            actor_net = actor_distribution_rnn_network.ActorDistributionRnnNetwork(
                augmented_observation_spec,
                action_spec,
                preprocessing_combiner=_preprocessing_combiner,
                input_fc_layer_params=actor_fc_layers,
                output_fc_layer_params=None)
            value_net = value_rnn_network.ValueRnnNetwork(
                augmented_observation_spec,
                preprocessing_combiner=_preprocessing_combiner,
                input_fc_layer_params=value_fc_layers,
                output_fc_layer_params=None)
        else:
            actor_net = actor_distribution_network.ActorDistributionNetwork(
                augmented_observation_spec,
                action_spec,
                preprocessing_combiner=_preprocessing_combiner,
                fc_layer_params=actor_fc_layers,
                name="actor_net")
            value_net = value_network.ValueNetwork(
                augmented_observation_spec,
                preprocessing_combiner=_preprocessing_combiner,
                fc_layer_params=value_fc_layers,
                name="critic_net")
        inference_net = actor_distribution_network.ActorDistributionNetwork(
            input_tensor_spec=input_inference_spec,
            output_tensor_spec=z_spec,
            fc_layer_params=inference_fc_layers,
            continuous_projection_net=normal_projection_net,
            name="inference_net")

        tf_agent = ppo_diayn_agent.PPODiaynAgent(
            augmented_time_step_spec,
            action_spec,
            z_spec,
            optimizer,
            actor_net=actor_net,
            value_net=value_net,
            inference_net=inference_net,
            num_epochs=num_epochs,
            debug_summaries=debug_summaries,
            summarize_grads_and_vars=summarize_grads_and_vars,
            train_step_counter=global_step,
            entropy_regularization=entropy_regularization,
            kl_posteriors_penalty=kl_posteriors_penalty,
            mock_inference=mock_inference,
            mock_reward=mock_reward,
            infer_from_com=infer_from_com,
            l2_distance=l2_distance,
            rl_steps=rl_steps,
            inference_steps=inference_steps)
        tf_agent.initialize()

        environment_steps_metric = tf_metrics.EnvironmentSteps()
        step_metrics = [
            tf_metrics.NumberOfEpisodes(),
            environment_steps_metric,
        ]

        train_metrics = step_metrics + [
            tf_metrics.AverageReturnMetric(
                batch_size=num_parallel_environments),
            tf_metrics.AverageEpisodeLengthMetric(
                batch_size=num_parallel_environments),
        ]

        eval_policy = tf_agent.policy
        collect_policy = tf_agent.collect_policy
        replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            tf_agent.collect_data_spec,
            batch_size=num_parallel_environments,
            max_length=replay_buffer_capacity)

        actor_checkpointer = common.Checkpointer(
            ckpt_dir=os.path.join(root_dir, 'diayn_actor'),
            actor_net=actor_net,
            global_step=global_step)
        train_checkpointer = common.Checkpointer(
            ckpt_dir=train_dir,
            agent=tf_agent,
            global_step=global_step,
            metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))
        policy_checkpointer = common.Checkpointer(
            ckpt_dir=os.path.join(train_dir, 'diayn_policy'),
            policy=eval_policy,
            global_step=global_step)
        saved_model = policy_saver.PolicySaver(eval_policy,
                                               train_step=global_step)
        rb_checkpointer = common.Checkpointer(
            ckpt_dir=os.path.join(root_dir, 'diayn_replay_buffer'),
            max_to_keep=1,
            replay_buffer=replay_buffer)
        inference_checkpointer = common.Checkpointer(
            ckpt_dir=os.path.join(root_dir, 'diayn_inference'),
            inference_net=inference_net,
            global_step=global_step)

        actor_checkpointer.initialize_or_restore()
        train_checkpointer.initialize_or_restore()
        rb_checkpointer.initialize_or_restore()
        inference_checkpointer.initialize_or_restore()

        collect_driver = dynamic_episode_driver.DynamicEpisodeDriver(
            tf_env,
            collect_policy,
            observers=[replay_buffer.add_batch] + train_metrics,
            num_episodes=collect_episodes_per_iteration)

        def train_step():
            # On-policy PPO: train on everything collected this iteration.
            trajectories = replay_buffer.gather_all()
            return tf_agent.train(experience=trajectories)

        if use_tf_functions:
            # TODO(b/123828980): Enable once the cause for slowdown was identified.
            collect_driver.run = common.function(collect_driver.run,
                                                 autograph=False)
            tf_agent.train = common.function(tf_agent.train, autograph=False)
            train_step = common.function(train_step)

        collect_time = 0
        train_time = 0
        timed_at_step = global_step.numpy()

        while environment_steps_metric.result() < num_environment_steps:
            global_step_val = global_step.numpy()
            if global_step_val % eval_interval == 0:
                metric_utils.eager_compute(
                    eval_metrics,
                    eval_tf_env,
                    eval_policy,
                    num_episodes=num_eval_episodes,
                    train_step=global_step,
                    summary_writer=eval_summary_writer,
                    summary_prefix='Metrics',
                )

            start_time = time.time()
            collect_driver.run()
            collect_time += time.time() - start_time

            start_time = time.time()
            total_loss, _ = train_step()
            replay_buffer.clear()
            train_time += time.time() - start_time

            for train_metric in train_metrics:
                train_metric.tf_summaries(train_step=global_step,
                                          step_metrics=step_metrics)

            if global_step_val % log_interval == 0:
                logging.info('step = %d, loss = %f', global_step_val,
                             total_loss)
                steps_per_sec = ((global_step_val - timed_at_step) /
                                 (collect_time + train_time))
                logging.info('%.3f steps/sec', steps_per_sec)
                logging.info('collect_time = {}, train_time = {}'.format(
                    collect_time, train_time))
                with tf.compat.v2.summary.record_if(True):
                    tf.compat.v2.summary.scalar(name='global_steps_per_sec',
                                                data=steps_per_sec,
                                                step=global_step)

                # Throughput is measured between consecutive log points.
                timed_at_step = global_step_val
                collect_time = 0
                train_time = 0

            # Checkpointing is scheduled independently of log_interval. When
            # it was nested inside the log branch, any checkpoint interval
            # that was not a multiple of log_interval never saved.
            if global_step_val % train_checkpoint_interval == 0:
                train_checkpointer.save(global_step=global_step_val)
                inference_checkpointer.save(global_step=global_step_val)
                actor_checkpointer.save(global_step=global_step_val)
                rb_checkpointer.save(global_step=global_step_val)

            if global_step_val % policy_checkpoint_interval == 0:
                policy_checkpointer.save(global_step=global_step_val)
                saved_model_path = os.path.join(
                    saved_model_dir,
                    'policy_' + ('%d' % global_step_val).zfill(9))
                saved_model.save(saved_model_path)

        # One final eval before exiting.
        metric_utils.eager_compute(
            eval_metrics,
            eval_tf_env,
            eval_policy,
            num_episodes=num_eval_episodes,
            train_step=global_step,
            summary_writer=eval_summary_writer,
            summary_prefix='Metrics',
        )
Example #7
0
def train_eval(
        root_dir,
        env_name='CartPole-v0',
        env_load_fn=suite_gym.load,
        random_seed=None,
        max_ep_steps=1000,
        # TODO(b/127576522): rename to policy_fc_layers.
        actor_fc_layers=(200, 100),
        value_fc_layers=(200, 100),
        use_rnns=False,
        # Params for collect
        num_environment_steps=5000000,
        collect_episodes_per_iteration=1,
        num_parallel_environments=1,
        replay_buffer_capacity=10000,  # Per-environment
        # Params for train
        num_epochs=25,
        learning_rate=1e-3,
        # Params for eval
        num_eval_episodes=10,
        num_random_episodes=1,
        eval_interval=500,
        # Params for summaries and logging
        train_checkpoint_interval=500,
        policy_checkpoint_interval=500,
        rb_checkpoint_interval=20000,
        log_interval=50,
        summary_interval=50,
        summaries_flush_secs=10,
        use_tf_functions=True,
        debug_summaries=False,
        eval_metrics_callback=None,
        random_metrics_callback=None,
        summarize_grads_and_vars=False):
    """Train and evaluate a PPO agent, comparing it with a random policy.

    Collection uses `num_parallel_environments` copies of
    `env_load_fn(env_name, max_episode_steps=max_ep_steps)`. Evaluation uses a
    separate instance built the same way. Whenever the global step is a
    multiple of `eval_interval`, two evaluations run:
      * the learned policy is evaluated and its metrics are written under
        'eval';
      * a `RandomTFPolicy` is evaluated and its metrics are written under
        'random'.
    Existing checkpoints in `root_dir` are restored before training.

    Args:
      root_dir: Directory for summaries, checkpoints and saved policies.
      env_name: Environment name passed to `env_load_fn`.
      env_load_fn: Callable `(env_name, max_episode_steps=...)` returning a py
        environment. Used for both training and evaluation environments.
      random_seed: If not None, passed to `tf.compat.v1.set_random_seed`.
      max_ep_steps: Maximum episode length passed to `env_load_fn`.
      actor_fc_layers: Sizes of the actor's fully connected layers.
      value_fc_layers: Sizes of the critic's fully connected layers.
      use_rnns: If True, use recurrent actor/critic networks.
      num_environment_steps: Stop once this many environment steps are done.
      collect_episodes_per_iteration: Episodes collected per iteration.
      num_parallel_environments: Number of collection environments; also the
        replay buffer batch size.
      replay_buffer_capacity: Per-environment replay buffer length.
      num_epochs: PPO epochs per train call.
      learning_rate: Adam learning rate.
      num_eval_episodes: Episodes per evaluation of the learned policy.
      num_random_episodes: Episodes per evaluation of the random policy.
      eval_interval: Evaluate when the global step is a multiple of this.
      train_checkpoint_interval: Save the agent checkpoint on multiples of
        this step.
      policy_checkpoint_interval: Save the policy checkpoint and a SavedModel
        on multiples of this step.
      rb_checkpoint_interval: Save the replay buffer on multiples of this
        step.
      log_interval: Log loss and throughput on multiples of this step.
      summary_interval: Record TF summaries on multiples of this step.
      summaries_flush_secs: Summary writer flush period in seconds.
      use_tf_functions: Wrap collection and training in `common.function`.
      debug_summaries: Passed to the agent.
      eval_metrics_callback: Optional callable, invoked as
        `callback(results, global_step_val)` after each periodic evaluation of
        the learned policy. `results` is the value returned by
        `metric_utils.eager_compute`.
      random_metrics_callback: Same as `eval_metrics_callback`, but for the
        random-policy evaluation.
      summarize_grads_and_vars: Passed to the agent.

    Raises:
      AttributeError: If `root_dir` is None.
    """
    # Set up the directories to contain the log data and model saves.
    # If data already exists in these folders, it is restored below.
    if root_dir is None:
        raise AttributeError('train_eval requires a root_dir.')

    root_dir = os.path.expanduser(root_dir)
    train_dir = os.path.join(root_dir, 'train')
    eval_dir = os.path.join(root_dir, 'eval')
    random_dir = os.path.join(root_dir, 'random')
    saved_model_dir = os.path.join(root_dir, 'policy_saved_model')

    # Create writers for logging and specify the metrics to log for each.
    train_summary_writer = tf.compat.v2.summary.create_file_writer(
        train_dir, flush_millis=summaries_flush_secs * 1000)
    train_summary_writer.set_as_default()

    eval_summary_writer = tf.compat.v2.summary.create_file_writer(
        eval_dir, flush_millis=summaries_flush_secs * 1000)

    eval_metrics = [
        tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
        tf_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes)
    ]

    random_summary_writer = tf.compat.v2.summary.create_file_writer(
        random_dir, flush_millis=summaries_flush_secs * 1000)

    random_metrics = [
        tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
        tf_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes)
    ]

    global_step = tf.compat.v1.train.get_or_create_global_step()

    # Set up the agent and train, recording summaries only on multiples of
    # summary_interval.
    with tf.compat.v2.summary.record_if(
            lambda: tf.math.equal(global_step % summary_interval, 0)):

        if random_seed is not None:
            tf.compat.v1.set_random_seed(random_seed)

        def _load_env():
            return env_load_fn(env_name, max_episode_steps=max_ep_steps)

        # Evaluation and training use the same environment definition.
        # Training previously hard-coded suite_gym.load and a single env,
        # which ignored env_load_fn. It also broke the replay buffer, whose
        # batch size is num_parallel_environments.
        eval_tf_env = tf_py_environment.TFPyEnvironment(_load_env())
        if num_parallel_environments == 1:
            py_env = _load_env()
        else:
            py_env = parallel_py_environment.ParallelPyEnvironment(
                [_load_env] * num_parallel_environments)
        tf_env = tf_py_environment.TFPyEnvironment(py_env)

        optimizer = tf.compat.v1.train.AdamOptimizer(
            learning_rate=learning_rate)

        if use_rnns:
            actor_net = actor_distribution_rnn_network.ActorDistributionRnnNetwork(
                tf_env.observation_spec(),
                tf_env.action_spec(),
                input_fc_layer_params=actor_fc_layers,
                output_fc_layer_params=None)

            value_net = value_rnn_network.ValueRnnNetwork(
                tf_env.observation_spec(),
                input_fc_layer_params=value_fc_layers,
                output_fc_layer_params=None)

        else:
            actor_net = actor_distribution_network.ActorDistributionNetwork(
                tf_env.observation_spec(),
                tf_env.action_spec(),
                fc_layer_params=actor_fc_layers,
                activation_fn=tf.keras.activations.tanh)

            value_net = value_network.ValueNetwork(
                tf_env.observation_spec(),
                fc_layer_params=value_fc_layers,
                activation_fn=tf.keras.activations.tanh)

        tf_agent = ppo_agent.PPOAgent(
            tf_env.time_step_spec(),
            tf_env.action_spec(),
            optimizer,
            actor_net=actor_net,
            value_net=value_net,
            entropy_regularization=0.0,
            importance_ratio_clipping=0.2,
            normalize_observations=False,
            normalize_rewards=False,
            use_gae=True,
            kl_cutoff_factor=0.0,
            initial_adaptive_kl_beta=0.0,
            num_epochs=num_epochs,
            debug_summaries=debug_summaries,
            summarize_grads_and_vars=summarize_grads_and_vars,
            train_step_counter=global_step)

        tf_agent.initialize()

        environment_steps_metric = tf_metrics.EnvironmentSteps()

        step_metrics = [
            tf_metrics.NumberOfEpisodes(),
            environment_steps_metric,
        ]

        train_metrics = step_metrics + [
            tf_metrics.AverageReturnMetric(
                batch_size=num_parallel_environments),
            tf_metrics.AverageEpisodeLengthMetric(
                batch_size=num_parallel_environments),
        ]

        eval_policy = tf_agent.policy
        collect_policy = tf_agent.collect_policy

        replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            tf_agent.collect_data_spec,
            batch_size=num_parallel_environments,
            max_length=replay_buffer_capacity)

        train_checkpointer = common.Checkpointer(
            ckpt_dir=train_dir,
            agent=tf_agent,
            global_step=global_step,
            metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))

        policy_checkpointer = common.Checkpointer(
            ckpt_dir=os.path.join(train_dir, 'policy'),
            policy=eval_policy,
            global_step=global_step)

        rb_checkpointer = common.Checkpointer(
            ckpt_dir=os.path.join(train_dir, 'replay_buffer'),
            max_to_keep=1,
            replay_buffer=replay_buffer)

        saved_model = policy_saver.PolicySaver(eval_policy,
                                               train_step=global_step)

        train_checkpointer.initialize_or_restore()
        rb_checkpointer.initialize_or_restore()

        collect_driver = dynamic_episode_driver.DynamicEpisodeDriver(
            tf_env,
            collect_policy,
            observers=[replay_buffer.add_batch] + train_metrics,
            num_episodes=collect_episodes_per_iteration)

        def train_step():
            # On-policy PPO: train on everything collected this iteration.
            trajectories = replay_buffer.gather_all()
            return tf_agent.train(experience=trajectories)

        if use_tf_functions:
            # TODO(b/123828980): Enable once the cause for slowdown was identified.
            collect_driver.run = common.function(collect_driver.run,
                                                 autograph=False)
            tf_agent.train = common.function(tf_agent.train, autograph=False)
            train_step = common.function(train_step)

        random_policy = random_tf_policy.RandomTFPolicy(
            eval_tf_env.time_step_spec(), eval_tf_env.action_spec())

        collect_time = 0
        train_time = 0
        timed_at_step = global_step.numpy()

        while environment_steps_metric.result() < num_environment_steps:
            global_step_val = global_step.numpy()

            if global_step_val % eval_interval == 0:
                eval_results = metric_utils.eager_compute(
                    eval_metrics,
                    eval_tf_env,
                    eval_policy,
                    num_episodes=num_eval_episodes,
                    train_step=global_step,
                    summary_writer=eval_summary_writer,
                    summary_prefix='Metrics',
                )
                if eval_metrics_callback is not None:
                    eval_metrics_callback(eval_results, global_step_val)

                random_results = metric_utils.eager_compute(
                    random_metrics,
                    eval_tf_env,
                    random_policy,
                    num_episodes=num_random_episodes,
                    train_step=global_step,
                    summary_writer=random_summary_writer,
                    summary_prefix='Metrics',
                )
                if random_metrics_callback is not None:
                    random_metrics_callback(random_results, global_step_val)

            start_time = time.time()
            collect_driver.run()
            collect_time += time.time() - start_time

            start_time = time.time()
            total_loss, _ = train_step()
            replay_buffer.clear()
            train_time += time.time() - start_time

            for train_metric in train_metrics:
                train_metric.tf_summaries(train_step=global_step,
                                          step_metrics=step_metrics)

            if global_step_val % log_interval == 0:
                logging.info('Step: {:>6d}\tLoss: {:>+20.4f}'.format(
                    global_step_val, total_loss))

                steps_per_sec = ((global_step_val - timed_at_step) /
                                 (collect_time + train_time))

                logging.info('{:6.3f} steps/sec'.format(steps_per_sec))
                logging.info(
                    'collect_time = {:.3f}, train_time = {:.3f}'.format(
                        collect_time, train_time))

                with tf.compat.v2.summary.record_if(True):
                    tf.compat.v2.summary.scalar(name='global_steps_per_sec',
                                                data=steps_per_sec,
                                                step=global_step)

                # Reset the timing window only after it has been reported.
                # Resetting every iteration made the logged rate and times
                # reflect a single iteration.
                timed_at_step = global_step_val
                collect_time = 0
                train_time = 0

            if global_step_val % train_checkpoint_interval == 0:
                train_checkpointer.save(global_step=global_step_val)

            if global_step_val % policy_checkpoint_interval == 0:
                policy_checkpointer.save(global_step=global_step_val)
                saved_model_path = os.path.join(
                    saved_model_dir,
                    'policy_' + ('%d' % global_step_val).zfill(9))
                saved_model.save(saved_model_path)

            if global_step_val % rb_checkpoint_interval == 0:
                rb_checkpointer.save(global_step=global_step_val)

        # One final eval before exiting.
        metric_utils.eager_compute(
            eval_metrics,
            eval_tf_env,
            eval_policy,
            num_episodes=num_eval_episodes,
            train_step=global_step,
            summary_writer=eval_summary_writer,
            summary_prefix='Metrics',
        )
def train_eval(
        root_dir,
        env_name='frozen_lake',
        env_load_fn=get_env,
        max_episode_steps=50,
        random_seed=None,
        # TODO(b/127576522): rename to policy_fc_layers.
        actor_fc_layers=(200, 100),
        value_fc_layers=(200, 100),
        use_rnns=False,
        # Params for collect
        num_environment_steps=25000000,
        collect_episodes_per_iteration=30,
        num_parallel_environments=30,
        replay_buffer_capacity=1001,  # Per-environment
        # Params for train
        num_epochs=25,
        learning_rate=1e-3,
        entropy_regularization=0.0,
        # Params for eval
        num_eval_episodes=30,
        eval_interval=25,
        # Params for summaries and logging
        train_checkpoint_interval=500,
        policy_checkpoint_interval=500,
        log_interval=50,
        summary_interval=50,
        summaries_flush_secs=1,
        use_tf_functions=True,
        debug_summaries=False,
        summarize_grads_and_vars=False):
    """A simple train and eval for PPO.

    Builds a PPOClipAgent over `num_parallel_environments` copies of the
    environment and alternates episode collection with training until the
    EnvironmentSteps metric reaches `num_environment_steps`. It evaluates
    periodically and once more before returning.

    All `*_interval` checks compare against the global step value read at the
    start of each loop iteration, i.e. before that iteration's training update.

    Args:
      root_dir: Base directory (`~` is expanded). Train summaries and
        checkpoints go under `train/`, eval summaries under `eval/`, and
        exported policies under `policy_saved_model/`.
      env_name: Environment name forwarded to `env_load_fn`.
      env_load_fn: Callable invoked as `env_load_fn(name=...,
        max_episode_steps=...)` that returns a Python environment. The
        eval environment must also provide `get_failure_state_vector()`.
      max_episode_steps: Episode length limit forwarded to `env_load_fn`.
      random_seed: If not None, passed to `tf.compat.v1.set_random_seed`.
      actor_fc_layers: Fully connected layer sizes for the actor network.
      value_fc_layers: Fully connected layer sizes for the value network.
      use_rnns: If True, use the RNN actor/value networks instead of the
        feed-forward ones.
      num_environment_steps: Training stops once the EnvironmentSteps metric
        reaches this value.
      collect_episodes_per_iteration: Episodes collected by the driver per
        training iteration.
      num_parallel_environments: Number of parallel training environments.
        Also used as the replay buffer and train-metric batch size.
      replay_buffer_capacity: Maximum replay buffer length per environment.
      num_epochs: Passed to the PPO agent as `num_epochs`.
      learning_rate: Adam optimizer learning rate.
      entropy_regularization: Passed to the PPO agent.
      num_eval_episodes: Episodes per evaluation. Also the buffer size of the
        eval metrics.
      eval_interval: Evaluate when the global step is a multiple of this.
      train_checkpoint_interval: Save a training checkpoint when the global
        step is a multiple of this.
      policy_checkpoint_interval: Save a policy checkpoint and a SavedModel
        export when the global step is a multiple of this.
      log_interval: Log loss and throughput, and record the
        `global_steps_per_sec` summary, when the global step is a multiple
        of this.
      summary_interval: Summaries are recorded only on global steps that are
        multiples of this.
      summaries_flush_secs: Flush interval of the summary writers, in seconds.
      use_tf_functions: If True, wrap collection and training in
        `common.function`.
      debug_summaries: Passed to the PPO agent.
      summarize_grads_and_vars: Passed to the PPO agent.

    Raises:
      AttributeError: If `root_dir` is None.
    """
    if root_dir is None:
        raise AttributeError('root_dir required.')

    # Consistent with the other trainer in this file: accept `~` paths.
    root_dir = os.path.expanduser(root_dir)
    train_dir = os.path.join(root_dir, 'train')
    eval_dir = os.path.join(root_dir, 'eval')
    saved_model_dir = os.path.join(root_dir, 'policy_saved_model')

    train_summary_writer = tf.compat.v2.summary.create_file_writer(
        train_dir, flush_millis=summaries_flush_secs * 1000)
    train_summary_writer.set_as_default()

    eval_summary_writer = tf.compat.v2.summary.create_file_writer(
        eval_dir, flush_millis=summaries_flush_secs * 1000)
    eval_metrics = [
        tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
        tf_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes)
    ]

    global_step = tf.compat.v1.train.get_or_create_global_step()
    # Start from 0; train_checkpointer.initialize_or_restore() below is
    # expected to overwrite this when a checkpoint exists.
    global_step.assign(0)
    with tf.compat.v2.summary.record_if(
            lambda: tf.math.equal(global_step % summary_interval, 0)):
        if random_seed is not None:
            tf.compat.v1.set_random_seed(random_seed)
        eval_env = env_load_fn(name=env_name,
                               max_episode_steps=max_episode_steps)
        # Used by the FailedEpisodes metric below to recognise failure states.
        failure_state_vector = eval_env.get_failure_state_vector()
        eval_tf_env = tf_py_environment.TFPyEnvironment(eval_env)

        tf_env = tf_py_environment.TFPyEnvironment(
            parallel_py_environment.ParallelPyEnvironment([
                lambda: env_load_fn(name=env_name,
                                    max_episode_steps=max_episode_steps)
            ] * num_parallel_environments))
        optimizer = tf.compat.v1.train.AdamOptimizer(
            learning_rate=learning_rate)

        if use_rnns:
            actor_net = actor_distribution_rnn_network.ActorDistributionRnnNetwork(
                tf_env.observation_spec(),
                tf_env.action_spec(),
                input_fc_layer_params=actor_fc_layers,
                output_fc_layer_params=None)
            value_net = value_rnn_network.ValueRnnNetwork(
                tf_env.observation_spec(),
                input_fc_layer_params=value_fc_layers,
                output_fc_layer_params=None)
        else:
            actor_net = actor_distribution_network.ActorDistributionNetwork(
                tf_env.observation_spec(),
                tf_env.action_spec(),
                fc_layer_params=actor_fc_layers,
                activation_fn=tf.keras.activations.tanh)
            value_net = value_network.ValueNetwork(
                tf_env.observation_spec(),
                fc_layer_params=value_fc_layers,
                activation_fn=tf.keras.activations.tanh)

        tf_agent = ppo_clip_agent.PPOClipAgent(
            tf_env.time_step_spec(),
            tf_env.action_spec(),
            optimizer,
            actor_net=actor_net,
            value_net=value_net,
            entropy_regularization=entropy_regularization,
            importance_ratio_clipping=0.2,
            normalize_observations=False,
            normalize_rewards=False,
            use_gae=True,
            num_epochs=num_epochs,
            debug_summaries=debug_summaries,
            summarize_grads_and_vars=summarize_grads_and_vars,
            train_step_counter=global_step)
        tf_agent.initialize()

        environment_steps_metric = tf_metrics.EnvironmentSteps()
        step_metrics = [
            tf_metrics.NumberOfEpisodes(),
            FailedEpisodes(failure_function=functools.partial(
                failure_function_discrete,
                failure_state_vector=failure_state_vector)),
            environment_steps_metric,
        ]

        train_metrics = step_metrics + [
            tf_metrics.AverageReturnMetric(
                batch_size=num_parallel_environments),
            tf_metrics.AverageEpisodeLengthMetric(
                batch_size=num_parallel_environments),
        ]

        eval_policy = tf_agent.policy
        collect_policy = tf_agent.collect_policy

        replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            tf_agent.collect_data_spec,
            batch_size=num_parallel_environments,
            max_length=replay_buffer_capacity)

        train_checkpointer = common.Checkpointer(
            ckpt_dir=train_dir,
            agent=tf_agent,
            global_step=global_step,
            metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))
        policy_checkpointer = common.Checkpointer(
            ckpt_dir=os.path.join(train_dir, 'policy'),
            policy=eval_policy,
            global_step=global_step)
        saved_model = policy_saver.PolicySaver(eval_policy,
                                               train_step=global_step)

        train_checkpointer.initialize_or_restore()

        collect_driver = dynamic_episode_driver.DynamicEpisodeDriver(
            tf_env,
            collect_policy,
            observers=[replay_buffer.add_batch] + train_metrics,
            num_episodes=collect_episodes_per_iteration)

        def train_step():
            # Train on everything currently in the buffer.
            trajectories = replay_buffer.gather_all()
            return tf_agent.train(experience=trajectories)

        if use_tf_functions:
            # TODO(b/123828980): Enable once the cause for slowdown was identified.
            collect_driver.run = common.function(collect_driver.run,
                                                 autograph=False)
            tf_agent.train = common.function(tf_agent.train, autograph=False)
            train_step = common.function(train_step)

        # Accumulated wall-clock seconds since the last log line.
        collect_time = 0
        train_time = 0
        timed_at_step = global_step.numpy()

        while environment_steps_metric.result() < num_environment_steps:
            # Interval checks below use the step value from before this
            # iteration's training update.
            global_step_val = global_step.numpy()
            if global_step_val % eval_interval == 0:
                metric_utils.eager_compute(
                    eval_metrics,
                    eval_tf_env,
                    eval_policy,
                    num_episodes=num_eval_episodes,
                    train_step=global_step,
                    summary_writer=eval_summary_writer,
                    summary_prefix='Metrics',
                )

            start_time = time.time()
            collect_driver.run()
            collect_time += time.time() - start_time

            start_time = time.time()
            total_loss, _ = train_step()
            # Each iteration trains only on freshly collected episodes.
            replay_buffer.clear()
            train_time += time.time() - start_time

            for train_metric in train_metrics:
                train_metric.tf_summaries(train_step=global_step,
                                          step_metrics=step_metrics)

            if global_step_val % log_interval == 0:
                logging.info('step = %d, loss = %f', global_step_val,
                             total_loss)
                steps_per_sec = ((global_step_val - timed_at_step) /
                                 (collect_time + train_time))
                logging.info('%.3f steps/sec', steps_per_sec)
                logging.info('collect_time = %.3f, train_time = %.3f',
                             collect_time, train_time)
                with tf.compat.v2.summary.record_if(True):
                    tf.compat.v2.summary.scalar(name='global_steps_per_sec',
                                                data=steps_per_sec,
                                                step=global_step)
                # Reset the timing window so the next rate covers the span
                # between consecutive log lines.
                timed_at_step = global_step_val
                collect_time = 0
                train_time = 0

            # Checkpointing is deliberately independent of log_interval.
            # Nesting it under the logging check would skip checkpoints
            # whenever an interval is not a multiple of log_interval.
            if global_step_val % train_checkpoint_interval == 0:
                train_checkpointer.save(global_step=global_step_val)

            if global_step_val % policy_checkpoint_interval == 0:
                policy_checkpointer.save(global_step=global_step_val)
                saved_model_path = os.path.join(
                    saved_model_dir,
                    'policy_' + ('%d' % global_step_val).zfill(9))
                saved_model.save(saved_model_path)

        # One final eval before exiting.
        metric_utils.eager_compute(
            eval_metrics,
            eval_tf_env,
            eval_policy,
            num_episodes=num_eval_episodes,
            train_step=global_step,
            summary_writer=eval_summary_writer,
            summary_prefix='Metrics',
        )