def get_networks(tf_env, networks_layers):
    assert isinstance(networks_layers, dict)
    actor_fc_layers = networks_layers["actor_net"]
    value_fc_layers = networks_layers["value_net"]
    actor_net = masked_networks.MaskedActorDistributionNetwork(
        tf_env.observation_spec(),
        tf_env.action_spec(),
        fc_layer_params=actor_fc_layers)
    value_net = masked_networks.MaskedValueNetwork(
        tf_env.observation_spec(), fc_layer_params=value_fc_layers)
    return actor_net, value_net
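
# Usage sketch (hypothetical layer sizes; assumes `tf_env` was built with the
# PyhanabiEnvWrapper + TFPyEnvironment pipeline used in the examples below):
#   actor_net, value_net = get_networks(
#       tf_env, {"actor_net": (150, 75), "value_net": (150, 75)})
#
# (The __init__ below is a method of a policy-loading class whose class header
# is not part of this snippet.)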
    def __init__(self,
                 root_dir,
                 game_type='Hanabi-Full',
                 num_players=4,
                 actor_fc_layers=(100, ),
                 value_fc_layers=(100, ),
                 use_value_network=False
                 ):
        tf.compat.v1.reset_default_graph()
        self.sess = tf.compat.v1.Session()
        tf.compat.v1.enable_resource_variables()

        pyhanabi_env = rl_env.make(environment_name=game_type, num_players=num_players)
        py_env = PyhanabiEnvWrapper(pyhanabi_env)
        tf_env = tf_py_environment.TFPyEnvironment(py_env)
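        # The TF environment is used here only to obtain the observation,
        # action and time-step specs needed to build the networks and agent.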

        with self.sess.as_default():
            # init the agent
            actor_net = masked_networks.MaskedActorDistributionNetwork(
                tf_env.observation_spec(),
                tf_env.action_spec(),
                fc_layer_params=actor_fc_layers
            )
            value_network = None
            if use_value_network:
                value_network = MaskedValueNetwork(
                    tf_env.observation_spec(),
                    fc_layer_params=value_fc_layers
                )

            global_step = tf.compat.v1.train.get_or_create_global_step()  # used as the agent's train_step_counter

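            # REINFORCE agent; when use_value_network is set, the value
            # network acts as a learned baseline for the return estimates.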
            tf_agent = reinforce_agent.ReinforceAgent(
                tf_env.time_step_spec(),
                tf_env.action_spec(),
                actor_network=actor_net,
                value_network=value_network if use_value_network else None,
                value_estimation_loss_coef=.2,
                gamma=.9,
                optimizer=tf.compat.v1.train.AdamOptimizer(learning_rate=1e-3),
                debug_summaries=False,
                summarize_grads_and_vars=False,
                train_step_counter=global_step
                )

            self.policy = py_tf_policy.PyTFPolicy(tf_agent.policy)

            # Load the trained policy checkpoint from root_dir.
            self.policy.initialize(None)
            self.policy.restore(root_dir)
            init_agent_op = tf_agent.initialize()

            self.sess.run(init_agent_op)
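
            # The restored policy can now be queried step by step, e.g.
            #   action_step = self.policy.action(time_step)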
Example #3
def train_eval(
    root_dir,
    tf_master='',
    env_name='Hanabi-Full',
    num_players=4,
    env_load_fn=None,
    random_seed=0,
    # TODO(b/127576522): rename to policy_fc_layers.
    actor_fc_layers=(150, 75),
    value_fc_layers=(150, 75),
    actor_fc_layers_rnn=(150,),
    value_fc_layers_rnn=(150,),
    use_rnns=False,
    # Params for collect
    num_environment_steps=int(1e09),
    collect_episodes_per_iteration=30,
    num_parallel_environments=30,
    replay_buffer_capacity=1001,  # Per-environment
    # Params for train
    num_epochs=25,
    learning_rate=1e-4,
    # Params for eval
    num_eval_episodes=30,
    eval_interval=500,
    # Params for summaries and logging
    train_checkpoint_interval=2000,
    policy_checkpoint_interval=1000,
    rb_checkpoint_interval=4000,
    log_interval=50,
    summary_interval=50,
    summaries_flush_secs=1,
    debug_summaries=False,
    summarize_grads_and_vars=False,
    eval_metrics_callback=None):
  """A simple train and eval for PPO."""
  if root_dir is None:
    raise AttributeError('train_eval requires a root_dir.')

  root_dir = os.path.expanduser(root_dir)
  train_dir = os.path.join(root_dir, 'train')
  eval_dir = os.path.join(root_dir, 'eval')

  train_summary_writer = tf.compat.v2.summary.create_file_writer(
      train_dir, flush_millis=summaries_flush_secs * 1000)
  train_summary_writer.set_as_default()

  eval_summary_writer = tf.compat.v2.summary.create_file_writer(
      eval_dir, flush_millis=summaries_flush_secs * 1000)
  eval_metrics = [
      batched_py_metric.BatchedPyMetric(
          AverageReturnMetric,
          metric_args={'buffer_size': num_eval_episodes},
          batch_size=num_parallel_environments),
      batched_py_metric.BatchedPyMetric(
          AverageEpisodeLengthMetric,
          metric_args={'buffer_size': num_eval_episodes},
          batch_size=num_parallel_environments),
  ]
  eval_summary_writer_flush_op = eval_summary_writer.flush()

  global_step = tf.compat.v1.train.get_or_create_global_step()
  with tf.compat.v2.summary.record_if(
      lambda: tf.math.equal(global_step % summary_interval, 0)):
    tf.compat.v1.set_random_seed(random_seed)
    eval_py_env2 = env_load_fn(env_name, num_players)
    eval_py_env = parallel_py_environment.ParallelPyEnvironment(
        [lambda: env_load_fn(env_name, num_players)] * num_parallel_environments)
    tf_env = tf_py_environment.TFPyEnvironment(
        parallel_py_environment.ParallelPyEnvironment(
            [lambda: env_load_fn(env_name, num_players)] * num_parallel_environments))
    optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=learning_rate)

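    # Build the masked actor/value networks; the masking keeps illegal Hanabi
    # moves out of the action distribution (see masked_networks).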
    if use_rnns:
      logging.info('Using RNN actor network.')
      actor_net = masked_networks.MaskedActorDistributionRnnNetwork(
          tf_env.observation_spec(),
          tf_env.action_spec(),
          input_fc_layer_params=actor_fc_layers_rnn,
          output_fc_layer_params=None,
          lstm_size=(75, 75))
    #  value_net = masked_networks.MaskedValueRnnNetwork(
    #      tf_env.observation_spec(),
    #      input_fc_layer_params=value_fc_layers_rnn,
    #      output_fc_layer_params=None,
    #      lstm_size=(256,256),)
      value_net = masked_networks.MaskedValueNetwork(
          tf_env.observation_spec(), fc_layer_params=value_fc_layers)
    else:
      actor_net = masked_networks.MaskedActorDistributionNetwork(
          tf_env.observation_spec(),
          tf_env.action_spec(),
          fc_layer_params=actor_fc_layers)
      value_net = masked_networks.MaskedValueNetwork(
          tf_env.observation_spec(), fc_layer_params=value_fc_layers)

    tf_agent = ppo_agent.PPOAgent(
        tf_env.time_step_spec(),
        tf_env.action_spec(),
        optimizer,
        actor_net=actor_net,
        value_net=value_net,
        num_epochs=num_epochs,
        debug_summaries=debug_summaries,
        summarize_grads_and_vars=summarize_grads_and_vars,
        train_step_counter=global_step,
        normalize_observations=False)  # the observations include the 0-1 legal-moves mask, so do not normalize them

    replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
        tf_agent.collect_data_spec,
        batch_size=num_parallel_environments,
        max_length=replay_buffer_capacity)

    eval_py_policy2 = py_tf_policy.PyTFPolicy(tf_agent.policy)
    eval_py_policy = py_tf_policy.PyTFPolicy(tf_agent.policy)

    environment_steps_metric = tf_metrics.EnvironmentSteps()
    environment_steps_count = environment_steps_metric.result()
    step_metrics = [
        tf_metrics.NumberOfEpisodes(),
        environment_steps_metric,
    ]
    train_metrics = step_metrics + [
        tf_metrics.AverageReturnMetric(),
        tf_metrics.AverageEpisodeLengthMetric(),
    ]

    # Add to replay buffer and other agent specific observers.
    replay_buffer_observer = [replay_buffer.add_batch]

    collect_policy = tf_agent.collect_policy

    collect_op = dynamic_episode_driver.DynamicEpisodeDriver(
        tf_env,
        collect_policy,
        observers=replay_buffer_observer + train_metrics,
        num_episodes=collect_episodes_per_iteration).run()

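    # On-policy PPO update: train on everything collected this iteration, then
    # clear the replay buffer before the next collection phase.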
    trajectories = replay_buffer.gather_all()

    train_op, _ = tf_agent.train(experience=trajectories)

    with tf.control_dependencies([train_op]):
      clear_replay_op = replay_buffer.clear()

    with tf.control_dependencies([clear_replay_op]):
      train_op = tf.identity(train_op)

    train_checkpointer = common.Checkpointer(
        ckpt_dir=train_dir,
        agent=tf_agent,
        global_step=global_step,
        metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))
    policy_checkpointer = common.Checkpointer(
        ckpt_dir=os.path.join(train_dir, 'policy'),
        policy=tf_agent.policy,
        global_step=global_step)
    rb_checkpointer = common.Checkpointer(
        ckpt_dir=os.path.join(train_dir, 'replay_buffer'),
        max_to_keep=1,
        replay_buffer=replay_buffer)

    summary_ops = []
    for train_metric in train_metrics:
      summary_ops.append(train_metric.tf_summaries(
          train_step=global_step, step_metrics=step_metrics))

    with eval_summary_writer.as_default(), \
         tf.compat.v2.summary.record_if(True):
      for eval_metric in eval_metrics:
        eval_metric.tf_summaries(
            train_step=global_step, step_metrics=step_metrics)

    init_agent_op = tf_agent.initialize()

    with tf.compat.v1.Session(tf_master) as sess:
      # Initialize graph.
      train_checkpointer.initialize_or_restore(sess)
      rb_checkpointer.initialize_or_restore(sess)
      common.initialize_uninitialized_variables(sess)

      sess.run(init_agent_op)
      sess.run(train_summary_writer.init())
      sess.run(eval_summary_writer.init())

      collect_time = 0
      train_time = 0
      timed_at_step = sess.run(global_step)
      steps_per_second_ph = tf.compat.v1.placeholder(
          tf.float32, shape=(), name='steps_per_sec_ph')
      steps_per_second_summary = tf.compat.v2.summary.scalar(
          name='global_steps_per_sec', data=steps_per_second_ph,
          step=global_step)

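      # Main loop: periodically evaluate, collect episodes, run one train step
      # on the gathered batch, then log and checkpoint at their intervals.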
      while sess.run(environment_steps_count) < num_environment_steps:
        global_step_val = sess.run(global_step)
        if global_step_val % eval_interval == 0:
          metric_utils.compute_summaries(
              eval_metrics,
              eval_py_env,
              eval_py_policy,
              num_episodes=num_eval_episodes,
              global_step=global_step_val,
              callback=eval_metrics_callback,
              log=True,
          )
          sess.run(eval_summary_writer_flush_op)
          # print('AVG RETURN:', compute_avg_return(eval_py_env2, eval_py_policy2))

        start_time = time.time()
        sess.run(collect_op)
        collect_time += time.time() - start_time
        start_time = time.time()
        total_loss, _ = sess.run([train_op, summary_ops])
        train_time += time.time() - start_time

        global_step_val = sess.run(global_step)
        if global_step_val % log_interval == 0:
          logging.info('step = %d, loss = %f', global_step_val, total_loss)
          steps_per_sec = (
              (global_step_val - timed_at_step) / (collect_time + train_time))
          logging.info('%.3f steps/sec', steps_per_sec)
          sess.run(
              steps_per_second_summary,
              feed_dict={steps_per_second_ph: steps_per_sec})
          logging.info('%s', 'collect_time = {}, train_time = {}'.format(
              collect_time, train_time))
          timed_at_step = global_step_val
          collect_time = 0
          train_time = 0

        if global_step_val % train_checkpoint_interval == 0:
          train_checkpointer.save(global_step=global_step_val)

        if global_step_val % policy_checkpoint_interval == 0:
          policy_checkpointer.save(global_step=global_step_val)

        if global_step_val % rb_checkpoint_interval == 0:
          rb_checkpointer.save(global_step=global_step_val)

      # One final eval before exiting.
      metric_utils.compute_summaries(
          eval_metrics,
          eval_py_env,
          eval_py_policy,
          num_episodes=num_eval_episodes,
          global_step=global_step_val,
          callback=eval_metrics_callback,
          log=True,
      )
      sess.run(eval_summary_writer_flush_op)
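
# Usage sketch (hypothetical root_dir; assumes an env_load_fn such as the
# load_env(game_type, num_players) helper used in the next example):
#   train_eval('/tmp/ppo_hanabi', env_load_fn=load_env, use_rnns=False)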
Example #4
def train_eval(
    root_dir,
    num_iterations=10000000,
    actor_fc_layers=(100,),
    value_net_fc_layers=(100,),
    use_value_network=False,
    # Params for collect
    collect_episodes_per_iteration=30,
    replay_buffer_capacity=1001,
    # Params for train
    learning_rate=1e-3,
    gamma=0.9,
    gradient_clipping=None,
    normalize_returns=True,
    value_estimation_loss_coef=0.2,
    # Params for eval
    num_eval_episodes=30,
    eval_interval=500,
    # Params for checkpoints, summaries, and logging
    train_checkpoint_interval=2000,
    policy_checkpoint_interval=1000,
    rb_checkpoint_interval=4000,
    log_interval=50,
    summary_interval=50,
    summaries_flush_secs=1,
    debug_summaries=True,
    summarize_grads_and_vars=False,
    eval_metrics_callback=None):
  """A simple train and eval for REINFORCE."""
  root_dir = os.path.expanduser(root_dir)
  train_dir = os.path.join(root_dir, 'train')
  eval_dir = os.path.join(root_dir, 'eval')

  train_summary_writer = tf.compat.v2.summary.create_file_writer(
      train_dir, flush_millis=summaries_flush_secs * 1000)
  train_summary_writer.set_as_default()

  eval_summary_writer = tf.compat.v2.summary.create_file_writer(
      eval_dir, flush_millis=summaries_flush_secs * 1000)
  eval_metrics = [
      batched_py_metric.BatchedPyMetric(
          AverageReturnMetric,
          metric_args={'buffer_size': num_eval_episodes},
          batch_size=30),  # must match the 30 parallel eval environments below
      batched_py_metric.BatchedPyMetric(
          AverageEpisodeLengthMetric,
          metric_args={'buffer_size': num_eval_episodes},
          batch_size=30),
  ]
  eval_summary_writer_flush_op = eval_summary_writer.flush()

  global_step = tf.compat.v1.train.get_or_create_global_step()
  with tf.compat.v2.summary.record_if(
      lambda: tf.math.equal(global_step % summary_interval, 0)):
    eval_py_env = parallel_py_environment.ParallelPyEnvironment(
        [lambda: load_env("Hanabi-Small", 4)] * 30)
    tf_env = tf_py_environment.TFPyEnvironment(
        parallel_py_environment.ParallelPyEnvironment(
            [lambda: load_env("Hanabi-Small", 4)] * 30))
    # tf_env = tf_py_environment.TFPyEnvironment(load_env())

    # TODO(b/127870767): Handle distributions without gin.
    actor_net = masked_networks.MaskedActorDistributionNetwork(
        tf_env.time_step_spec().observation,
        tf_env.action_spec(),
        fc_layer_params=actor_fc_layers)

    value_net = None
    if use_value_network:
      value_net = masked_networks.MaskedValueNetwork(
          tf_env.time_step_spec().observation,
          fc_layer_params=value_net_fc_layers)

    tf_agent = reinforce_agent.ReinforceAgent(
        tf_env.time_step_spec(),
        tf_env.action_spec(),
        actor_network=actor_net,
        value_network=value_net if use_value_network else None,
        value_estimation_loss_coef=value_estimation_loss_coef,
        gamma=gamma,
        optimizer=tf.compat.v1.train.AdamOptimizer(learning_rate=learning_rate),
        normalize_returns=normalize_returns,
        gradient_clipping=gradient_clipping,
        debug_summaries=debug_summaries,
        summarize_grads_and_vars=summarize_grads_and_vars,
        train_step_counter=global_step)

    replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
        tf_agent.collect_data_spec,
        batch_size=tf_env.batch_size,
        max_length=replay_buffer_capacity)

    eval_py_policy = py_tf_policy.PyTFPolicy(tf_agent.policy)
    # eval_py_policy_custom_return = py_tf_policy.PyTFPolicy(tf_agent.policy)

    train_metrics = [
        tf_metrics.NumberOfEpisodes(),
        tf_metrics.EnvironmentSteps(),
        tf_metrics.AverageReturnMetric(),
        tf_metrics.AverageEpisodeLengthMetric(),
    ]

    collect_policy = tf_agent.collect_policy

    collect_op = dynamic_episode_driver.DynamicEpisodeDriver(
        tf_env,
        collect_policy,
        observers=[replay_buffer.add_batch] + train_metrics,
        num_episodes=collect_episodes_per_iteration).run()

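    # On-policy REINFORCE update: gather all collected episodes, train once,
    # and clear the buffer (clear_rb_op is run after each train step below).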
    experience = replay_buffer.gather_all()
    train_op = tf_agent.train(experience)
    clear_rb_op = replay_buffer.clear()

    train_checkpointer = common.Checkpointer(
        ckpt_dir=train_dir,
        agent=tf_agent,
        global_step=global_step,
        metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))
    policy_checkpointer = common.Checkpointer(
        ckpt_dir=os.path.join(train_dir, 'policy'),
        policy=tf_agent.policy,
        global_step=global_step)
    rb_checkpointer = common.Checkpointer(
        ckpt_dir=os.path.join(train_dir, 'replay_buffer'),
        max_to_keep=1,
        replay_buffer=replay_buffer)

    summary_ops = []
    for train_metric in train_metrics:
      summary_ops.append(train_metric.tf_summaries(
          train_step=global_step, step_metrics=train_metrics[:2]))

    with eval_summary_writer.as_default(), \
         tf.compat.v2.summary.record_if(True):
      for eval_metric in eval_metrics:
        eval_metric.tf_summaries(train_step=global_step)

    init_agent_op = tf_agent.initialize()

    with tf.compat.v1.Session() as sess:
      # Initialize the graph.
      train_checkpointer.initialize_or_restore(sess)
      rb_checkpointer.initialize_or_restore(sess)
      # TODO(b/126239733): Remove once Periodically can be saved.
      common.initialize_uninitialized_variables(sess)

      sess.run(init_agent_op)
      sess.run(train_summary_writer.init())
      sess.run(eval_summary_writer.init())

      # Compute evaluation metrics.
      global_step_call = sess.make_callable(global_step)
      global_step_val = global_step_call()
      metric_utils.compute_summaries(
          eval_metrics,
          eval_py_env,
          eval_py_policy,
          num_episodes=num_eval_episodes,
          global_step=global_step_val,
          callback=eval_metrics_callback,
      )

      collect_call = sess.make_callable(collect_op)
      train_step_call = sess.make_callable([train_op, summary_ops])
      clear_rb_call = sess.make_callable(clear_rb_op)

      timed_at_step = global_step_call()
      time_acc = 0
      steps_per_second_ph = tf.compat.v1.placeholder(
          tf.float32, shape=(), name='steps_per_sec_ph')
      steps_per_second_summary = tf.compat.v2.summary.scalar(
          name='global_steps_per_sec', data=steps_per_second_ph,
          step=global_step)

      for _ in range(num_iterations):
        start_time = time.time()
        collect_call()
        total_loss, _ = train_step_call()
        clear_rb_call()
        time_acc += time.time() - start_time
        global_step_val = global_step_call()

        if global_step_val % log_interval == 0:
          logging.info('step = %d, loss = %f', global_step_val, total_loss.loss)
          steps_per_sec = (global_step_val - timed_at_step) / time_acc
          logging.info('%.3f steps/sec', steps_per_sec)
          sess.run(
              steps_per_second_summary,
              feed_dict={steps_per_second_ph: steps_per_sec})
          timed_at_step = global_step_val
          time_acc = 0

        if global_step_val % train_checkpoint_interval == 0:
          train_checkpointer.save(global_step=global_step_val)

        if global_step_val % policy_checkpoint_interval == 0:
          policy_checkpointer.save(global_step=global_step_val)

        if global_step_val % rb_checkpoint_interval == 0:
          rb_checkpointer.save(global_step=global_step_val)

        if global_step_val % eval_interval == 0:
          metric_utils.compute_summaries(
              eval_metrics,
              eval_py_env,
              eval_py_policy,
              num_episodes=num_eval_episodes,
              global_step=global_step_val,
              callback=eval_metrics_callback,
          )
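
# Usage sketch (hypothetical root_dir):
#   train_eval('/tmp/reinforce_hanabi', num_iterations=100000)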
Example #5
    def __init__(
            self,
            root_dir,
            game_type,
            num_players,
            actor_fc_layers=(150, 75),
            value_fc_layers=(150, 75),
            actor_fc_layers_rnn=(150, ),
            value_fc_layers_rnn=(150, ),
            lstm_size=(75, 75),
            use_rnns=False,
    ):

        tf.compat.v1.reset_default_graph()
        self.sess = tf.compat.v1.Session()
        tf.compat.v1.enable_resource_variables()

        pyhanabi_env = rl_env.make(environment_name=game_type,
                                   num_players=num_players)
        py_env = PyhanabiEnvWrapper(pyhanabi_env)
        tf_env = tf_py_environment.TFPyEnvironment(py_env)

        with self.sess.as_default():
            # init the agent
            if use_rnns:
                actor_net = masked_networks.MaskedActorDistributionRnnNetwork(
                    tf_env.observation_spec(),
                    tf_env.action_spec(),
                    input_fc_layer_params=actor_fc_layers_rnn,
                    output_fc_layer_params=None,
                    lstm_size=lstm_size,
                )
                value_net = MaskedValueNetwork(tf_env.observation_spec(),
                                               fc_layer_params=value_fc_layers)
            else:
                actor_net = masked_networks.MaskedActorDistributionNetwork(
                    tf_env.observation_spec(),
                    tf_env.action_spec(),
                    fc_layer_params=actor_fc_layers)
                value_net = MaskedValueNetwork(tf_env.observation_spec(),
                                               fc_layer_params=value_fc_layers)

            global_step = tf.compat.v1.train.get_or_create_global_step()  # used as the agent's train_step_counter
            tf_agent = ppo_agent.PPOAgent(
                tf_env.time_step_spec(),
                tf_env.action_spec(),
                actor_net=actor_net,
                value_net=value_net,
                train_step_counter=global_step,
                normalize_observations=False
            )  # the observations include the 0-1 legal-moves mask, so do not normalize them

            self.policy = py_tf_policy.PyTFPolicy(tf_agent.policy)

            # load checkpoint
            train_dir = os.path.join(root_dir, 'train')
            self.policy.initialize(None)
            self.policy.restore(root_dir)
            init_agent_op = tf_agent.initialize()

            self.sess.run(init_agent_op)
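
            # As in the REINFORCE loader above, the restored PPO policy is
            # exposed through self.policy for step-by-step action selection.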