Example #1
def main(argv):
  del argv

  if FLAGS.jax_debug_nans:
    config.update("jax_debug_nans", True)

  bottom_layers = common_stax_layers()

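  # For Pong, rescale the pixel observations (divide by 255) and flatten them
  # before the fully connected layers.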
  if FLAGS.env_name == "Pong-v0":
    bottom_layers = [stax.Div(255.0), stax.Flatten(2)] + bottom_layers

  optimizer_fun = functools.partial(ppo.optimizer_fun,
                                    step_size=FLAGS.learning_rate)

  ppo.training_loop(
      env_name=FLAGS.env_name,
      epochs=FLAGS.epochs,
      policy_net_fun=functools.partial(
          ppo.policy_net, bottom_layers=bottom_layers),
      value_net_fun=functools.partial(
          ppo.value_net, bottom_layers=bottom_layers),
      policy_optimizer_fun=optimizer_fun,
      value_optimizer_fun=optimizer_fun,
      batch_size=FLAGS.batch_size,
      num_optimizer_steps=FLAGS.num_optimizer_steps,
      boundary=FLAGS.boundary,
      random_seed=FLAGS.random_seed)
Example #2
def main(argv):
    del argv
    logging.set_verbosity(FLAGS.log_level)
    bottom_layers = common_stax_layers()
    ppo.training_loop(
        env_name=FLAGS.env,
        epochs=FLAGS.epochs,
        policy_net_fun=functools.partial(ppo.policy_net,
                                         bottom_layers=bottom_layers),
        value_net_fun=functools.partial(ppo.value_net,
                                        bottom_layers=bottom_layers),
        random_seed=FLAGS.random_seed)
Example #3
def _run_training_loop(self, train_env, eval_env, output_dir):
  n_epochs = 2
  # Run the training loop.
  ppo.training_loop(
      train_env=train_env,
      eval_env=eval_env,
      epochs=n_epochs,
      policy_and_value_net_fn=functools.partial(
          ppo.policy_and_value_net,
          bottom_layers_fn=lambda: [layers.Dense(1)]),
      policy_and_value_optimizer_fn=ppo.optimizer_fn,
      n_optimizer_steps=1,
      output_dir=output_dir,
      random_seed=0)
Example #4
  def run_training_loop():
    optimizer_fun = functools.partial(
        ppo.optimizer_fun, step_size=FLAGS.learning_rate)

    ppo.training_loop(
        env_name=FLAGS.env_name,
        epochs=FLAGS.epochs,
        policy_and_value_net_fun=functools.partial(
            ppo.policy_and_value_net, bottom_layers=common_layers()),
        policy_and_value_optimizer_fun=optimizer_fun,
        batch_size=FLAGS.batch_size,
        num_optimizer_steps=FLAGS.num_optimizer_steps,
        boundary=FLAGS.boundary,
        max_timestep=FLAGS.max_timestep,
        random_seed=FLAGS.random_seed)
Example #5
    def run_training_loop():
        """Runs the training loop."""
        logging.info("Starting the training loop.")

        policy_and_value_net_fn = functools.partial(
            ppo.policy_and_value_net,
            bottom_layers_fn=bottom_layers_fn,
            two_towers=FLAGS.two_towers)
        policy_and_value_optimizer_fn = get_optimizer_fn(FLAGS.learning_rate)

        ppo.training_loop(
            output_dir=FLAGS.output_dir,
            train_env=train_env,
            eval_env=eval_env,
            policy_and_value_net_fn=policy_and_value_net_fn,
            policy_and_value_optimizer_fn=policy_and_value_optimizer_fn,
        )
Example #6
def _run_training_loop(self, env_name, output_dir):
  env = self.get_wrapped_env(env_name, 2)
  eval_env = self.get_wrapped_env(env_name, 2)
  n_epochs = 2
  # Run the training loop.
  ppo.training_loop(
      env=env,
      eval_env=eval_env,
      epochs=n_epochs,
      policy_and_value_net_fn=functools.partial(
          ppo.policy_and_value_net,
          bottom_layers_fn=lambda: [layers.Dense(1)]),
      policy_and_value_optimizer_fn=ppo.optimizer_fn,
      n_optimizer_steps=1,
      output_dir=output_dir,
      env_name=env_name,
      random_seed=0)
Example #7
    def run_training_loop():
        """Runs the training loop."""
        policy_net_fun = None
        value_net_fun = None
        policy_and_value_net_fun = None
        policy_optimizer_fun = None
        value_optimizer_fun = None
        policy_and_value_optimizer_fun = None

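        # Build either a single combined policy-and-value network or separate
        # policy and value networks, each with its own optimizer.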
        if FLAGS.combined_policy_and_value_function:
            policy_and_value_net_fun = functools.partial(
                ppo.policy_and_value_net, bottom_layers=common_layers())
            policy_and_value_optimizer_fun = get_optimizer_fun(
                FLAGS.learning_rate)
        else:
            policy_net_fun = functools.partial(ppo.policy_net,
                                               bottom_layers=common_layers())
            value_net_fun = functools.partial(ppo.value_net,
                                              bottom_layers=common_layers())
            policy_optimizer_fun = get_optimizer_fun(
                FLAGS.policy_only_learning_rate)
            value_optimizer_fun = get_optimizer_fun(
                FLAGS.value_only_learning_rate)

        ppo.training_loop(
            env=env,
            epochs=FLAGS.epochs,
            policy_net_fun=policy_net_fun,
            value_net_fun=value_net_fun,
            policy_and_value_net_fun=policy_and_value_net_fun,
            policy_optimizer_fun=policy_optimizer_fun,
            value_optimizer_fun=value_optimizer_fun,
            policy_and_value_optimizer_fun=policy_and_value_optimizer_fun,
            num_optimizer_steps=FLAGS.num_optimizer_steps,
            policy_only_num_optimizer_steps=(
                FLAGS.policy_only_num_optimizer_steps),
            value_only_num_optimizer_steps=(
                FLAGS.value_only_num_optimizer_steps),
            batch_size=FLAGS.batch_size,
            target_kl=FLAGS.target_kl,
            boundary=FLAGS.boundary,
            max_timestep=FLAGS.max_timestep,
            random_seed=FLAGS.random_seed,
            c1=FLAGS.value_coef,
            c2=FLAGS.entropy_coef)
Example #8
def main(argv):
  del argv
  logging.set_verbosity(FLAGS.log_level)
  bottom_layers = common_stax_layers()

  if FLAGS.env_name == "Pong-v0":
    bottom_layers = [stax.Div(255.0), stax.Flatten(2)] + bottom_layers

  ppo.training_loop(
      env_name=FLAGS.env_name,
      epochs=FLAGS.epochs,
      policy_net_fun=functools.partial(
          ppo.policy_net, bottom_layers=bottom_layers),
      value_net_fun=functools.partial(
          ppo.value_net, bottom_layers=bottom_layers),
      batch_size=FLAGS.batch_size,
      boundary=FLAGS.boundary,
      random_seed=FLAGS.random_seed)
Example #9
def test_training_loop(self):
    with self.tmp_dir() as output_dir:
        env = self.get_wrapped_env("CartPole-v0", 2)
        eval_env = self.get_wrapped_env("CartPole-v0", 2)
        num_epochs = 2
        batch_size = 2
        # Run the training loop.
        ppo.training_loop(
            env=env,
            eval_env=eval_env,
            epochs=num_epochs,
            policy_and_value_net_fun=functools.partial(
                ppo.policy_and_value_net,
                bottom_layers_fn=lambda: [layers.Dense(1)]),
            policy_and_value_optimizer_fun=ppo.optimizer_fun,
            batch_size=batch_size,
            num_optimizer_steps=1,
            output_dir=output_dir,
            random_seed=0)
Example #10
  def run_training_loop():
    """Runs the training loop."""
    logging.info("Starting the training loop.")

    policy_and_value_net_fn = functools.partial(
        ppo.policy_and_value_net,
        bottom_layers_fn=common_layers,
        two_towers=FLAGS.two_towers)
    policy_and_value_optimizer_fn = get_optimizer_fn(FLAGS.learning_rate)

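    # Fall back to random_seed=None if the flag value does not parse as an integer.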
    random_seed = None
    try:
      random_seed = int(FLAGS.random_seed)
    except Exception:  # pylint: disable=broad-except
      pass

    ppo.training_loop(
        env=env,
        epochs=FLAGS.epochs,
        policy_and_value_net_fn=policy_and_value_net_fn,
        policy_and_value_optimizer_fn=policy_and_value_optimizer_fn,
        n_optimizer_steps=FLAGS.n_optimizer_steps,
        print_every_optimizer_steps=FLAGS.print_every_optimizer_steps,
        batch_size=FLAGS.batch_size,
        target_kl=FLAGS.target_kl,
        boundary=FLAGS.boundary,
        max_timestep=FLAGS.truncation_timestep,
        max_timestep_eval=FLAGS.truncation_timestep_eval,
        random_seed=random_seed,
        c1=FLAGS.value_coef,
        c2=FLAGS.entropy_coef,
        gamma=FLAGS.gamma,
        lambda_=FLAGS.lambda_,
        epsilon=FLAGS.epsilon,
        enable_early_stopping=FLAGS.enable_early_stopping,
        output_dir=FLAGS.output_dir,
        eval_every_n=FLAGS.eval_every_n,
        done_frac_for_policy_save=FLAGS.done_frac_for_policy_save,
        eval_env=eval_env,
        n_evals=FLAGS.n_evals,
        env_name=str(FLAGS.env_problem_name),
        len_history_for_policy=int(FLAGS.len_history_for_policy),
    )
Example #11
def run_training_loop():
    """Run the PPO training loop."""

    policy_net_fun = None
    value_net_fun = None
    policy_and_value_net_fun = None
    policy_optimizer_fun = None
    value_optimizer_fun = None
    policy_and_value_optimizer_fun = None

    if FLAGS.combined_policy_and_value_function:
        policy_and_value_net_fun = functools.partial(
            ppo.policy_and_value_net, bottom_layers=common_layers())
        policy_and_value_optimizer_fun = get_optimizer_fun(
            FLAGS.policy_and_value_net_learning_rate)
    else:
        policy_net_fun = functools.partial(ppo.policy_net,
                                           bottom_layers=common_layers())
        value_net_fun = functools.partial(ppo.value_net,
                                          bottom_layers=common_layers())
        policy_optimizer_fun = get_optimizer_fun(
            FLAGS.policy_net_learning_rate)
        value_optimizer_fun = get_optimizer_fun(FLAGS.value_net_learning_rate)

    ppo.training_loop(
        env_name=FLAGS.env_name,
        epochs=FLAGS.epochs,
        policy_net_fun=policy_net_fun,
        value_net_fun=value_net_fun,
        policy_and_value_net_fun=policy_and_value_net_fun,
        policy_optimizer_fun=policy_optimizer_fun,
        value_optimizer_fun=value_optimizer_fun,
        policy_and_value_optimizer_fun=policy_and_value_optimizer_fun,
        batch_size=FLAGS.batch_size,
        num_optimizer_steps=FLAGS.num_optimizer_steps,
        boundary=FLAGS.boundary,
        max_timestep=FLAGS.max_timestep,
        random_seed=FLAGS.random_seed)
Example #12
def test_training_loop(self):
    env = gym.make("CartPole-v0")
    # Usually gym envs are wrapped in a TimeLimit wrapper.
    env = gym_utils.remove_time_limit_wrapper(env)
    # Limit this to a small number of steps for tests.
    env = gym.wrappers.TimeLimit(env, max_episode_steps=2)
    num_epochs = 2
    batch_size = 2
    _, rewards, val_losses, ppo_objectives = ppo.training_loop(
        env=env,
        epochs=num_epochs,
        batch_size=batch_size,
        num_optimizer_steps=1)
    self.assertLen(rewards, num_epochs)
    self.assertLen(val_losses, num_epochs)
    self.assertLen(ppo_objectives, num_epochs)
Example #13
def test_training_loop_policy_and_value_function(self):
    env = self.get_wrapped_env("CartPole-v0", 2)
    num_epochs = 2
    batch_size = 2
    # Run the training loop.
    _, rewards, val_losses, ppo_objectives = ppo.training_loop(
        env=env,
        epochs=num_epochs,
        policy_and_value_net_fun=functools.partial(
            ppo.policy_and_value_net, bottom_layers=[layers.Dense(1)]),
        policy_and_value_optimizer_fun=ppo.optimizer_fun,
        batch_size=batch_size,
        num_optimizer_steps=1,
        random_seed=0)
    self.assertLen(rewards, num_epochs)
    self.assertLen(val_losses, num_epochs)
    self.assertLen(ppo_objectives, num_epochs)
Example #14
def test_training_loop(self):
    env = gym.make("CartPole-v0")
    # Usually gym envs are wrapped in a TimeLimit wrapper.
    env = gym_utils.remove_time_limit_wrapper(env)
    # Limit this to a small number of steps for tests.
    env = gym.wrappers.TimeLimit(env, max_episode_steps=2)
    num_epochs = 2
    batch_size = 2
    # Run the training loop.
    _, rewards, val_losses, ppo_objectives = ppo.training_loop(
        env=env,
        epochs=num_epochs,
        policy_net_fun=functools.partial(
            ppo.policy_net, bottom_layers=[stax.Dense(1)]),
        value_net_fun=functools.partial(
            ppo.value_net, bottom_layers=[stax.Dense(1)]),
        batch_size=batch_size,
        num_optimizer_steps=1,
        random_seed=0)
    self.assertLen(rewards, num_epochs)
    self.assertLen(val_losses, num_epochs)
    self.assertLen(ppo_objectives, num_epochs)
Example #15
    def run_training_loop():
        """Runs the training loop."""
        policy_net_fun = None
        value_net_fun = None
        policy_and_value_net_fun = None
        policy_optimizer_fun = None
        value_optimizer_fun = None
        policy_and_value_optimizer_fun = None

        if FLAGS.combined_network:
            policy_and_value_net_fun = functools.partial(
                ppo.policy_and_value_net,
                bottom_layers_fn=common_layers,
                two_towers=FLAGS.two_towers)
            policy_and_value_optimizer_fun = get_optimizer_fun(
                FLAGS.learning_rate)
        else:
            policy_net_fun = functools.partial(ppo.policy_net,
                                               bottom_layers=common_layers())
            value_net_fun = functools.partial(ppo.value_net,
                                              bottom_layers=common_layers())
            policy_optimizer_fun = get_optimizer_fun(
                FLAGS.policy_only_learning_rate)
            value_optimizer_fun = get_optimizer_fun(
                FLAGS.value_only_learning_rate)

        random_seed = None
        try:
            random_seed = int(FLAGS.random_seed)
        except Exception:  # pylint: disable=broad-except
            pass

        ppo.training_loop(
            env=env,
            epochs=FLAGS.epochs,
            policy_net_fun=policy_net_fun,
            value_net_fun=value_net_fun,
            policy_and_value_net_fun=policy_and_value_net_fun,
            policy_optimizer_fun=policy_optimizer_fun,
            value_optimizer_fun=value_optimizer_fun,
            policy_and_value_optimizer_fun=policy_and_value_optimizer_fun,
            num_optimizer_steps=FLAGS.num_optimizer_steps,
            policy_only_num_optimizer_steps=(
                FLAGS.policy_only_num_optimizer_steps),
            value_only_num_optimizer_steps=(
                FLAGS.value_only_num_optimizer_steps),
            print_every_optimizer_steps=FLAGS.print_every_optimizer_steps,
            batch_size=FLAGS.batch_size,
            target_kl=FLAGS.target_kl,
            boundary=FLAGS.boundary,
            max_timestep=FLAGS.truncation_timestep,
            max_timestep_eval=FLAGS.truncation_timestep_eval,
            random_seed=random_seed,
            c1=FLAGS.value_coef,
            c2=FLAGS.entropy_coef,
            gamma=FLAGS.gamma,
            lambda_=FLAGS.lambda_,
            epsilon=FLAGS.epsilon,
            enable_early_stopping=FLAGS.enable_early_stopping,
            output_dir=FLAGS.output_dir,
            eval_every_n=FLAGS.eval_every_n,
            eval_env=eval_env)
Example #16
def main(unused_argv):
    logging.set_verbosity(FLAGS.log_level)
    ppo.training_loop(env_name=FLAGS.env, epochs=FLAGS.epochs)
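
Note: all of the examples above read their hyperparameters from absl flags that are defined elsewhere in each trainer module. As a rough orientation, the sketch below shows how a few of those flags might be declared; the flag names are taken from the examples, but the types, defaults, and help strings are assumptions rather than the library's actual definitions.

from absl import flags

FLAGS = flags.FLAGS

# Illustrative flag declarations; names mirror the FLAGS.* references above,
# defaults are placeholders only.
flags.DEFINE_string("env_name", "CartPole-v0", "Gym environment to train on.")
flags.DEFINE_integer("epochs", 100, "Number of PPO training epochs.")
flags.DEFINE_integer("batch_size", 32, "Number of trajectories per batch.")
flags.DEFINE_integer("num_optimizer_steps", 100, "Optimizer steps per epoch.")
flags.DEFINE_float("learning_rate", 1e-3, "Optimizer step size.")
flags.DEFINE_integer("random_seed", 0, "Random seed for the training loop.")
flags.DEFINE_boolean("combined_policy_and_value_function", False,
                     "Use one network for both the policy and the value head.")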