def main(argv):
  del argv
  if FLAGS.jax_debug_nans:
    config.update("jax_debug_nans", True)
  bottom_layers = common_stax_layers()
  if FLAGS.env_name == "Pong-v0":
    bottom_layers = [stax.Div(255.0), stax.Flatten(2)] + bottom_layers
  optimizer_fun = functools.partial(
      ppo.optimizer_fun, step_size=FLAGS.learning_rate)
  ppo.training_loop(
      env_name=FLAGS.env_name,
      epochs=FLAGS.epochs,
      policy_net_fun=functools.partial(
          ppo.policy_net, bottom_layers=bottom_layers),
      value_net_fun=functools.partial(
          ppo.value_net, bottom_layers=bottom_layers),
      policy_optimizer_fun=optimizer_fun,
      value_optimizer_fun=optimizer_fun,
      batch_size=FLAGS.batch_size,
      num_optimizer_steps=FLAGS.num_optimizer_steps,
      boundary=FLAGS.boundary,
      random_seed=FLAGS.random_seed)
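# Note: common_stax_layers() is referenced by the main() snippets here but its
# definition is not included. A minimal sketch of what such a helper could
# return, assuming plain fully connected stax layers (hypothetical sizes):
def common_stax_layers():
  # A small MLP trunk shared by the policy and value networks.
  return [stax.Dense(16), stax.Relu, stax.Dense(16), stax.Relu]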
def main(argv):
  del argv
  logging.set_verbosity(FLAGS.log_level)
  bottom_layers = common_stax_layers()
  ppo.training_loop(
      env_name=FLAGS.env,
      epochs=FLAGS.epochs,
      policy_net_fun=functools.partial(
          ppo.policy_net, bottom_layers=bottom_layers),
      value_net_fun=functools.partial(
          ppo.value_net, bottom_layers=bottom_layers),
      random_seed=FLAGS.random_seed)
def _run_training_loop(self, train_env, eval_env, output_dir):
  n_epochs = 2
  # Run the training loop.
  ppo.training_loop(
      train_env=train_env,
      eval_env=eval_env,
      epochs=n_epochs,
      policy_and_value_net_fn=functools.partial(
          ppo.policy_and_value_net,
          bottom_layers_fn=lambda: [layers.Dense(1)]),
      policy_and_value_optimizer_fn=ppo.optimizer_fn,
      n_optimizer_steps=1,
      output_dir=output_dir,
      random_seed=0)
def run_training_loop():
  """Runs the training loop."""
  optimizer_fun = functools.partial(
      ppo.optimizer_fun, step_size=FLAGS.learning_rate)
  ppo.training_loop(
      env_name=FLAGS.env_name,
      epochs=FLAGS.epochs,
      policy_and_value_net_fun=functools.partial(
          ppo.policy_and_value_net, bottom_layers=common_layers()),
      policy_and_value_optimizer_fun=optimizer_fun,
      batch_size=FLAGS.batch_size,
      num_optimizer_steps=FLAGS.num_optimizer_steps,
      boundary=FLAGS.boundary,
      max_timestep=FLAGS.max_timestep,
      random_seed=FLAGS.random_seed)
def run_training_loop(): """Runs the training loop.""" logging.info("Starting the training loop.") policy_and_value_net_fn = functools.partial( ppo.policy_and_value_net, bottom_layers_fn=bottom_layers_fn, two_towers=FLAGS.two_towers) policy_and_value_optimizer_fn = get_optimizer_fn(FLAGS.learning_rate) ppo.training_loop( output_dir=FLAGS.output_dir, train_env=train_env, eval_env=eval_env, policy_and_value_net_fn=policy_and_value_net_fn, policy_and_value_optimizer_fn=policy_and_value_optimizer_fn, )
def _run_training_loop(self, env_name, output_dir):
  env = self.get_wrapped_env(env_name, 2)
  eval_env = self.get_wrapped_env(env_name, 2)
  n_epochs = 2
  # Run the training loop.
  ppo.training_loop(
      env=env,
      eval_env=eval_env,
      epochs=n_epochs,
      policy_and_value_net_fn=functools.partial(
          ppo.policy_and_value_net,
          bottom_layers_fn=lambda: [layers.Dense(1)]),
      policy_and_value_optimizer_fn=ppo.optimizer_fn,
      n_optimizer_steps=1,
      output_dir=output_dir,
      env_name=env_name,
      random_seed=0)
def run_training_loop(): """Runs the training loop.""" policy_net_fun = None value_net_fun = None policy_and_value_net_fun = None policy_optimizer_fun = None value_optimizer_fun = None policy_and_value_optimizer_fun = None if FLAGS.combined_policy_and_value_function: policy_and_value_net_fun = functools.partial( ppo.policy_and_value_net, bottom_layers=common_layers()) policy_and_value_optimizer_fun = get_optimizer_fun( FLAGS.learning_rate) else: policy_net_fun = functools.partial(ppo.policy_net, bottom_layers=common_layers()) value_net_fun = functools.partial(ppo.value_net, bottom_layers=common_layers()) policy_optimizer_fun = get_optimizer_fun( FLAGS.policy_only_learning_rate) value_optimizer_fun = get_optimizer_fun( FLAGS.value_only_learning_rate) ppo.training_loop( env=env, epochs=FLAGS.epochs, policy_net_fun=policy_net_fun, value_net_fun=value_net_fun, policy_and_value_net_fun=policy_and_value_net_fun, policy_optimizer_fun=policy_optimizer_fun, value_optimizer_fun=value_optimizer_fun, policy_and_value_optimizer_fun=policy_and_value_optimizer_fun, num_optimizer_steps=FLAGS.num_optimizer_steps, policy_only_num_optimizer_steps=FLAGS. policy_only_num_optimizer_steps, value_only_num_optimizer_steps=FLAGS. value_only_num_optimizer_steps, batch_size=FLAGS.batch_size, target_kl=FLAGS.target_kl, boundary=FLAGS.boundary, max_timestep=FLAGS.max_timestep, random_seed=FLAGS.random_seed, c1=FLAGS.value_coef, c2=FLAGS.entropy_coef)
def main(argv):
  del argv
  logging.set_verbosity(FLAGS.log_level)
  bottom_layers = common_stax_layers()
  if FLAGS.env_name == "Pong-v0":
    bottom_layers = [stax.Div(255.0), stax.Flatten(2)] + bottom_layers
  ppo.training_loop(
      env_name=FLAGS.env_name,
      epochs=FLAGS.epochs,
      policy_net_fun=functools.partial(
          ppo.policy_net, bottom_layers=bottom_layers),
      value_net_fun=functools.partial(
          ppo.value_net, bottom_layers=bottom_layers),
      batch_size=FLAGS.batch_size,
      boundary=FLAGS.boundary,
      random_seed=FLAGS.random_seed)
def test_training_loop(self):
  with self.tmp_dir() as output_dir:
    env = self.get_wrapped_env("CartPole-v0", 2)
    eval_env = self.get_wrapped_env("CartPole-v0", 2)
    num_epochs = 2
    batch_size = 2
    # Run the training loop.
    ppo.training_loop(
        env=env,
        eval_env=eval_env,
        epochs=num_epochs,
        policy_and_value_net_fun=functools.partial(
            ppo.policy_and_value_net,
            bottom_layers_fn=lambda: [layers.Dense(1)]),
        policy_and_value_optimizer_fun=ppo.optimizer_fun,
        batch_size=batch_size,
        num_optimizer_steps=1,
        output_dir=output_dir,
        random_seed=0)
def run_training_loop(): """Runs the training loop.""" logging.info("Starting the training loop.") policy_and_value_net_fn = functools.partial( ppo.policy_and_value_net, bottom_layers_fn=common_layers, two_towers=FLAGS.two_towers) policy_and_value_optimizer_fn = get_optimizer_fn(FLAGS.learning_rate) random_seed = None try: random_seed = int(FLAGS.random_seed) except Exception: # pylint: disable=broad-except pass ppo.training_loop( env=env, epochs=FLAGS.epochs, policy_and_value_net_fn=policy_and_value_net_fn, policy_and_value_optimizer_fn=policy_and_value_optimizer_fn, n_optimizer_steps=FLAGS.n_optimizer_steps, print_every_optimizer_steps=FLAGS.print_every_optimizer_steps, batch_size=FLAGS.batch_size, target_kl=FLAGS.target_kl, boundary=FLAGS.boundary, max_timestep=FLAGS.truncation_timestep, max_timestep_eval=FLAGS.truncation_timestep_eval, random_seed=random_seed, c1=FLAGS.value_coef, c2=FLAGS.entropy_coef, gamma=FLAGS.gamma, lambda_=FLAGS.lambda_, epsilon=FLAGS.epsilon, enable_early_stopping=FLAGS.enable_early_stopping, output_dir=FLAGS.output_dir, eval_every_n=FLAGS.eval_every_n, done_frac_for_policy_save=FLAGS.done_frac_for_policy_save, eval_env=eval_env, n_evals=FLAGS.n_evals, env_name=str(FLAGS.env_problem_name), len_history_for_policy=int(FLAGS.len_history_for_policy), )
def run_training_loop(): """Run the PPO training loop.""" policy_net_fun = None value_net_fun = None policy_and_value_net_fun = None policy_optimizer_fun = None value_optimizer_fun = None policy_and_value_optimizer_fun = None if FLAGS.combined_policy_and_value_function: policy_and_value_net_fun = functools.partial( ppo.policy_and_value_net, bottom_layers=common_layers()) policy_and_value_optimizer_fun = get_optimizer_fun( FLAGS.policy_and_value_net_learning_rate) else: policy_net_fun = functools.partial(ppo.policy_net, bottom_layers=common_layers()) value_net_fun = functools.partial(ppo.value_net, bottom_layers=common_layers()) policy_optimizer_fun = get_optimizer_fun( FLAGS.policy_net_learning_rate) value_optimizer_fun = get_optimizer_fun(FLAGS.value_net_learning_rate) ppo.training_loop( env_name=FLAGS.env_name, epochs=FLAGS.epochs, policy_net_fun=policy_net_fun, value_net_fun=value_net_fun, policy_and_value_net_fun=policy_and_value_net_fun, policy_optimizer_fun=policy_optimizer_fun, value_optimizer_fun=value_optimizer_fun, policy_and_value_optimizer_fun=policy_and_value_optimizer_fun, batch_size=FLAGS.batch_size, num_optimizer_steps=FLAGS.num_optimizer_steps, boundary=FLAGS.boundary, max_timestep=FLAGS.max_timestep, random_seed=FLAGS.random_seed)
def test_training_loop(self): env = gym.make("CartPole-v0") # Usually gym envs are wrapped in TimeLimit wrapper. env = gym_utils.remove_time_limit_wrapper(env) # Limit this to a small number for tests. env = gym.wrappers.TimeLimit(env, max_episode_steps=2) num_epochs = 2 batch_size = 2 _, rewards, val_losses, ppo_objectives = ppo.training_loop( env=env, epochs=num_epochs, batch_size=batch_size, num_optimizer_steps=1) self.assertLen(rewards, num_epochs) self.assertLen(val_losses, num_epochs) self.assertLen(ppo_objectives, num_epochs)
def test_training_loop_policy_and_value_function(self):
  env = self.get_wrapped_env("CartPole-v0", 2)
  num_epochs = 2
  batch_size = 2
  # Run the training loop.
  _, rewards, val_losses, ppo_objectives = ppo.training_loop(
      env=env,
      epochs=num_epochs,
      policy_and_value_net_fun=functools.partial(
          ppo.policy_and_value_net, bottom_layers=[layers.Dense(1)]),
      policy_and_value_optimizer_fun=ppo.optimizer_fun,
      batch_size=batch_size,
      num_optimizer_steps=1,
      random_seed=0)
  self.assertLen(rewards, num_epochs)
  self.assertLen(val_losses, num_epochs)
  self.assertLen(ppo_objectives, num_epochs)
def test_training_loop(self): env = gym.make("CartPole-v0") # Usually gym envs are wrapped in TimeLimit wrapper. env = gym_utils.remove_time_limit_wrapper(env) # Limit this to a small number for tests. env = gym.wrappers.TimeLimit(env, max_episode_steps=2) num_epochs = 2 batch_size = 2 # Run the training loop. _, rewards, val_losses, ppo_objectives = ppo.training_loop( env=env, epochs=num_epochs, policy_net_fun=functools.partial(ppo.policy_net, bottom_layers=[stax.Dense(1)]), value_net_fun=functools.partial(ppo.value_net, bottom_layers=[stax.Dense(1)]), batch_size=batch_size, num_optimizer_steps=1, random_seed=0) self.assertLen(rewards, num_epochs) self.assertLen(val_losses, num_epochs) self.assertLen(ppo_objectives, num_epochs)
def run_training_loop(): """Runs the training loop.""" policy_net_fun = None value_net_fun = None policy_and_value_net_fun = None policy_optimizer_fun = None value_optimizer_fun = None policy_and_value_optimizer_fun = None if FLAGS.combined_network: policy_and_value_net_fun = functools.partial( ppo.policy_and_value_net, bottom_layers_fn=common_layers, two_towers=FLAGS.two_towers) policy_and_value_optimizer_fun = get_optimizer_fun( FLAGS.learning_rate) else: policy_net_fun = functools.partial(ppo.policy_net, bottom_layers=common_layers()) value_net_fun = functools.partial(ppo.value_net, bottom_layers=common_layers()) policy_optimizer_fun = get_optimizer_fun( FLAGS.policy_only_learning_rate) value_optimizer_fun = get_optimizer_fun( FLAGS.value_only_learning_rate) random_seed = None try: random_seed = int(FLAGS.random_seed) except Exception: # pylint: disable=broad-except pass ppo.training_loop( env=env, epochs=FLAGS.epochs, policy_net_fun=policy_net_fun, value_net_fun=value_net_fun, policy_and_value_net_fun=policy_and_value_net_fun, policy_optimizer_fun=policy_optimizer_fun, value_optimizer_fun=value_optimizer_fun, policy_and_value_optimizer_fun=policy_and_value_optimizer_fun, num_optimizer_steps=FLAGS.num_optimizer_steps, policy_only_num_optimizer_steps=FLAGS. policy_only_num_optimizer_steps, value_only_num_optimizer_steps=FLAGS. value_only_num_optimizer_steps, print_every_optimizer_steps=FLAGS.print_every_optimizer_steps, batch_size=FLAGS.batch_size, target_kl=FLAGS.target_kl, boundary=FLAGS.boundary, max_timestep=FLAGS.truncation_timestep, max_timestep_eval=FLAGS.truncation_timestep_eval, random_seed=random_seed, c1=FLAGS.value_coef, c2=FLAGS.entropy_coef, gamma=FLAGS.gamma, lambda_=FLAGS.lambda_, epsilon=FLAGS.epsilon, enable_early_stopping=FLAGS.enable_early_stopping, output_dir=FLAGS.output_dir, eval_every_n=FLAGS.eval_every_n, eval_env=eval_env)
def main(unused_argv):
  logging.set_verbosity(FLAGS.log_level)
  ppo.training_loop(env_name=FLAGS.env, epochs=FLAGS.epochs)
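# The main() functions above read absl flags such as FLAGS.env, FLAGS.epochs
# and FLAGS.log_level, but the flag definitions and entry point are not part
# of these snippets. A minimal setup would look roughly like this (the flag
# names follow the snippets; the defaults are hypothetical):
from absl import app
from absl import flags
from absl import logging

FLAGS = flags.FLAGS

flags.DEFINE_string("env", "CartPole-v0", "Name of the Gym environment.")
flags.DEFINE_integer("epochs", 100, "Number of PPO training epochs.")
flags.DEFINE_string("log_level", "INFO", "Verbosity passed to absl logging.")

if __name__ == "__main__":
  app.run(main)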