def _tf_setup_from_flags():
  """Applies the TensorFlow-related command-line flags.

  Flags read here:
    enable_eager_execution: when set, TF eager execution is switched on.
    tf_xla: when set, the TF optimizer JIT setting is switched on and
      `tf_xla_forced_compile` is forwarded to `fastmath.tf_math`.
    tf_opt_pin_to_host, tf_opt_layout: handed to the TF optimizer as
      experimental options.
    tf_allow_float64: forwarded to `tf_np.set_allow_float64`.
  """
  if FLAGS.enable_eager_execution:
    tf.compat.v1.enable_eager_execution()

  if FLAGS.tf_xla:
    tf.config.optimizer.set_jit(True)
    fastmath.tf_math.set_tf_xla_forced_compile(FLAGS.tf_xla_forced_compile)

  # These options are set unconditionally, whether or not XLA is enabled.
  optimizer_options = {
      'pin_to_host_optimization': FLAGS.tf_opt_pin_to_host,
      'layout_optimizer': FLAGS.tf_opt_layout,
  }
  tf.config.optimizer.set_experimental_options(optimizer_options)

  tf_np.set_allow_float64(FLAGS.tf_allow_float64)
def train_rl(output_dir,
             n_epochs=10000,
             light_rl=True,
             light_rl_trainer=light_trainers.PolicyGradient):
  """Train the RL agent.

  Builds an `rl_task.RLTask`, applies the JAX-related flags, creates the
  trainer and runs it. The trainer is closed exactly once, whether or not
  training raised.

  Args:
    output_dir: Output directory.
    n_epochs: Number epochs to run the training for.
    light_rl: deprecated, always True, left out for old gin configs.
    light_rl_trainer: which light RL trainer to use (experimental).

  Raises:
    Exception: anything raised while running the trainer is re-raised after
      the trainer has been closed.
  """
  del light_rl  # Deprecated; accepted only so that old configs still bind.
  tf_np.set_allow_float64(FLAGS.tf_allow_float64)
  task = rl_task.RLTask()

  if FLAGS.jax_debug_nans:
    config.update('jax_debug_nans', True)

  if FLAGS.use_tpu:
    config.update('jax_platform_name', 'tpu')
  else:
    config.update('jax_platform_name', '')

  trainer = light_rl_trainer(task=task, output_dir=output_dir)

  def light_training_loop():
    """Run the trainer for n_epochs and call close on it."""
    try:
      logging.info('Starting RL training for %d epochs.', n_epochs)
      trainer.run(n_epochs, n_epochs_is_total_epochs=True)
      logging.info('Completed RL training for %d epochs.', n_epochs)
    except Exception:
      # Only reached on failure, so the message is accurate; the bare `raise`
      # propagates the original exception unchanged.
      logging.info(
          'Encountered an exception, still calling trainer.close()')
      raise
    finally:
      # Single close for both the success and the failure path.
      trainer.close()
      logging.info('Trainer is now closed.')

  if FLAGS.jax_debug_nans or FLAGS.disable_jit:
    fastmath.disable_jit()
    with jax.disable_jit():
      light_training_loop()
  else:
    light_training_loop()
def train_rl(
    output_dir,
    train_batch_size,
    eval_batch_size,
    env_name='Acrobot-v1',
    max_timestep=None,
    clip_rewards=False,
    rendered_env=False,
    resize=False,
    resize_dims=(105, 80),
    trainer_class=rl_trainers.PPO,
    n_epochs=10000,
    trajectory_dump_dir=None,
    num_actions=None,
    light_rl=False,
    light_rl_trainer=light_trainers.RLTrainer,
):
  """Train the RL agent.

  When `light_rl` is true, only `output_dir`, `n_epochs` and
  `light_rl_trainer` are used: a light trainer is run on an `rl_task.RLTask`
  and the function returns before any environment is created. Otherwise a
  train and an eval environment are built and `trainer_class` runs its
  training loop on them.

  Args:
    output_dir: Output directory.
    train_batch_size: Number of parallel environments to use for training.
    eval_batch_size: Number of parallel environments to use for evaluation.
    env_name: Name of the environment.
    max_timestep: Int or None, the maximum number of timesteps in a trajectory.
      The environment is wrapped in a TimeLimit wrapper.
    clip_rewards: Whether to clip and discretize the rewards.
    rendered_env: Whether the environment has visual input. If so, a
      RenderedEnvProblem will be used.
    resize: whether to do resize or not
    resize_dims: Pair (height, width), dimensions to resize the visual
      observations to.
    trainer_class: RLTrainer class to use.
    n_epochs: Number epochs to run the training for.
    trajectory_dump_dir: Directory to dump trajectories to.
    num_actions: None unless one wants to use the discretization wrapper. Then
      num_actions specifies the number of discrete actions.
    light_rl: whether to use the light RL setting (experimental).
    light_rl_trainer: which light RL trainer to use (experimental).

  Raises:
    AssertionError: if either created environment is falsy.
  """
  tf_np.set_allow_float64(FLAGS.tf_allow_float64)

  if light_rl:
    # Created before the JAX config updates below, as in the original order.
    task = rl_task.RLTask()

  if FLAGS.jax_debug_nans:
    config.update('jax_debug_nans', True)

  if FLAGS.use_tpu:
    config.update('jax_platform_name', 'tpu')
  else:
    config.update('jax_platform_name', '')

  # NOTE(review): `math` below is presumably the project's math backend
  # module (it exposes `disable_jit`), not the standard library -- confirm
  # against the file's imports.

  if light_rl:
    trainer = light_rl_trainer(task=task, output_dir=output_dir)

    def light_training_loop():
      """Runs the light trainer and closes it even if the run fails."""
      try:
        trainer.run(n_epochs, n_epochs_is_total_epochs=True)
      finally:
        trainer.close()

    if FLAGS.jax_debug_nans or FLAGS.disable_jit:
      math.disable_jit()
      with jax.disable_jit():
        light_training_loop()
    else:
      light_training_loop()
    return

  # TODO(pkozakowski): Find a better way to determine this.
  train_env_kwargs = {}
  eval_env_kwargs = {}
  if 'OnlineTuneEnv' in env_name:
    envs_output_dir = FLAGS.envs_output_dir or os.path.join(output_dir, 'envs')
    train_env_output_dir = os.path.join(envs_output_dir, 'train')
    eval_env_output_dir = os.path.join(envs_output_dir, 'eval')
    train_env_kwargs = {'output_dir': train_env_output_dir}
    eval_env_kwargs = {'output_dir': eval_env_output_dir}

  parallelism = multiprocessing.cpu_count() if FLAGS.parallelize_envs else 1

  logging.info('Num discretized actions %s', num_actions)
  logging.info('Resize %d', resize)

  train_env = env_problem_utils.make_env(
      batch_size=train_batch_size,
      env_problem_name=env_name,
      rendered_env=rendered_env,
      resize=resize,
      resize_dims=resize_dims,
      max_timestep=max_timestep,
      clip_rewards=clip_rewards,
      parallelism=parallelism,
      use_tpu=FLAGS.use_tpu,
      num_actions=num_actions,
      **train_env_kwargs)
  assert train_env

  eval_env = env_problem_utils.make_env(
      batch_size=eval_batch_size,
      env_problem_name=env_name,
      rendered_env=rendered_env,
      resize=resize,
      resize_dims=resize_dims,
      max_timestep=max_timestep,
      clip_rewards=clip_rewards,
      parallelism=parallelism,
      use_tpu=FLAGS.use_tpu,
      num_actions=num_actions,
      **eval_env_kwargs)
  assert eval_env

  def run_training_loop():
    """Runs the training loop."""
    logging.info('Starting the training loop.')

    trainer = trainer_class(
        output_dir=output_dir,
        train_env=train_env,
        eval_env=eval_env,
        trajectory_dump_dir=trajectory_dump_dir,
        async_mode=FLAGS.async_mode,
    )
    trainer.training_loop(n_epochs=n_epochs)

  if FLAGS.jax_debug_nans or FLAGS.disable_jit:
    math.disable_jit()
    with jax.disable_jit():
      run_training_loop()
  else:
    run_training_loop()
def tearDown(self):
  """Puts back the float64 setting stored on the instance, then tears down.

  The value restored is `self._old_is_allow_float64`; the base class
  `tearDown` runs afterwards.
  """
  previous_setting = self._old_is_allow_float64
  tf_np.set_allow_float64(previous_setting)
  super().tearDown()
def setUp(self):
  """Runs base set-up, then turns float64 off for the test.

  Calls `test_utils.ensure_flag('test_tmpdir')` and stores the float64
  setting that was active beforehand in `self._old_is_allow_float64`.
  """
  super().setUp()
  test_utils.ensure_flag('test_tmpdir')

  # Capture the current value first; it is overwritten on the next line.
  currently_allowed = tf_np.is_allow_float64()
  self._old_is_allow_float64 = currently_allowed
  tf_np.set_allow_float64(False)
def train_rl(
    output_dir,
    train_batch_size,
    eval_batch_size,
    env_name='Acrobot-v1',
    max_timestep=None,
    clip_rewards=False,
    rendered_env=False,
    resize=False,
    resize_dims=(105, 80),
    trainer_class=None,
    n_epochs=10000,
    trajectory_dump_dir=None,
    num_actions=None,
    light_rl=True,
    light_rl_trainer=light_trainers.PolicyGradient,
):
  """Train the RL agent.

  Builds an `rl_task.RLTask`, applies the JAX-related flags, creates the
  trainer and runs it. The trainer is closed exactly once, whether or not
  training raised.

  Only `output_dir`, `n_epochs` and `light_rl_trainer` are used by this
  function; the remaining parameters are accepted but not referenced in the
  body.

  Args:
    output_dir: Output directory.
    train_batch_size: Number of parallel environments to use for training.
    eval_batch_size: Number of parallel environments to use for evaluation.
    env_name: Name of the environment.
    max_timestep: Int or None, the maximum number of timesteps in a trajectory.
      The environment is wrapped in a TimeLimit wrapper.
    clip_rewards: Whether to clip and discretize the rewards.
    rendered_env: Whether the environment has visual input. If so, a
      RenderedEnvProblem will be used.
    resize: whether to do resize or not
    resize_dims: Pair (height, width), dimensions to resize the visual
      observations to.
    trainer_class: RLTrainer class to use.
    n_epochs: Number epochs to run the training for.
    trajectory_dump_dir: Directory to dump trajectories to.
    num_actions: None unless one wants to use the discretization wrapper. Then
      num_actions specifies the number of discrete actions.
    light_rl: deprecated, always True, left out for old gin configs.
    light_rl_trainer: which light RL trainer to use (experimental).

  Raises:
    Exception: anything raised while running the trainer is re-raised after
      the trainer has been closed.
  """
  del light_rl  # Deprecated; accepted only so that old configs still bind.
  tf_np.set_allow_float64(FLAGS.tf_allow_float64)
  task = rl_task.RLTask()

  if FLAGS.jax_debug_nans:
    config.update('jax_debug_nans', True)

  if FLAGS.use_tpu:
    config.update('jax_platform_name', 'tpu')
  else:
    config.update('jax_platform_name', '')

  trainer = light_rl_trainer(task=task, output_dir=output_dir)

  def light_training_loop():
    """Run the trainer for n_epochs and call close on it."""
    try:
      logging.info('Starting RL training for %d epochs.', n_epochs)
      trainer.run(n_epochs, n_epochs_is_total_epochs=True)
      logging.info('Completed RL training for %d epochs.', n_epochs)
    except Exception:
      # Only reached on failure, so the message is accurate; the bare `raise`
      # propagates the original exception unchanged.
      logging.info(
          'Encountered an exception, still calling trainer.close()')
      raise
    finally:
      # Single close for both the success and the failure path.
      trainer.close()
      logging.info('Trainer is now closed.')

  if FLAGS.jax_debug_nans or FLAGS.disable_jit:
    fastmath.disable_jit()
    with jax.disable_jit():
      light_training_loop()
  else:
    light_training_loop()
def set_tf_allow_float64(b):
  """Forwards `b` to `tf_np.set_allow_float64`.

  Args:
    b: value passed through unchanged (presumably a bool -- confirm against
      `tf_np`).
  """
  tf_np.set_allow_float64(b)