Example 1
File: trainer.py Project: MLDL/trax
def _tf_setup_from_flags():
  """Processes TensorFlow-relevant flags."""
  if FLAGS.enable_eager_execution:
    tf.compat.v1.enable_eager_execution()
  if FLAGS.tf_xla:
    tf.config.optimizer.set_jit(True)
    fastmath.tf_math.set_tf_xla_forced_compile(FLAGS.tf_xla_forced_compile)
  tf.config.optimizer.set_experimental_options({
      'pin_to_host_optimization': FLAGS.tf_opt_pin_to_host,
      'layout_optimizer': FLAGS.tf_opt_layout,
  })
  tf_np.set_allow_float64(FLAGS.tf_allow_float64)
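The helper in Example 1 only reads absl flags; their definitions live elsewhere in trainer.py. For context, here is a minimal sketch of how such flags could be defined and the helper wired into an absl entry point. The flag names come from the snippet above, but the defaults, help strings, and main() wiring are assumptions for illustration, not code from MLDL/trax.
# A minimal sketch (assumed defaults and help strings, not copied from
# MLDL/trax) of the absl flag definitions that Example 1 reads, plus the
# usual app.run() wiring.
from absl import app, flags

flags.DEFINE_bool('enable_eager_execution', True, 'Enable TF eager execution.')
flags.DEFINE_bool('tf_xla', True, 'Compile TF graphs with XLA.')
flags.DEFINE_bool('tf_xla_forced_compile', False, 'Force XLA compilation.')
flags.DEFINE_bool('tf_opt_pin_to_host', False, 'Pin-to-host grappler pass.')
flags.DEFINE_bool('tf_opt_layout', False, 'Layout-optimizer grappler pass.')
flags.DEFINE_bool('tf_allow_float64', False, 'Allow float64 in TF-NumPy.')
FLAGS = flags.FLAGS


def main(_):
  _tf_setup_from_flags()  # apply the TF settings once, before training starts
  # ... build and run the trainer ...


if __name__ == '__main__':
  app.run(main)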
Example 2
def train_rl(output_dir,
             n_epochs=10000,
             light_rl=True,
             light_rl_trainer=light_trainers.PolicyGradient):
  """Train the RL agent.

  Args:
    output_dir: Output directory.
    n_epochs: Number of epochs to run the training for.
    light_rl: deprecated, always True, left out for old gin configs.
    light_rl_trainer: which light RL trainer to use (experimental).
  """
  del light_rl
  tf_np.set_allow_float64(FLAGS.tf_allow_float64)
  task = rl_task.RLTask()
  env_name = task.env_name

  if FLAGS.jax_debug_nans:
    config.update('jax_debug_nans', True)

  if FLAGS.use_tpu:
    config.update('jax_platform_name', 'tpu')
  else:
    config.update('jax_platform_name', '')

  trainer = light_rl_trainer(task=task, output_dir=output_dir)

  def light_training_loop():
    """Run the trainer for n_epochs and call close on it."""
    try:
      logging.info('Starting RL training for %d epochs.', n_epochs)
      trainer.run(n_epochs, n_epochs_is_total_epochs=True)
      logging.info('Completed RL training for %d epochs.', n_epochs)
      trainer.close()
      logging.info('Trainer is now closed.')
    except Exception as e:
      raise e
    finally:
      logging.info(
          'Encountered an exception, still calling trainer.close()')
      trainer.close()
      logging.info('Trainer is now closed.')

  if FLAGS.jax_debug_nans or FLAGS.disable_jit:
    fastmath.disable_jit()
    with jax.disable_jit():
      light_training_loop()
  else:
    light_training_loop()
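When FLAGS.jax_debug_nans or FLAGS.disable_jit is set, the function above disables JIT twice, once in Trax's fastmath layer and once in JAX itself, and it also turns on JAX's NaN checker. Below is a standalone sketch of the same debugging setup using only public JAX APIs; the Trax flag handling is omitted and the loss function is illustrative.
# A minimal sketch of the NaN-debugging setup used above, outside of Trax.
import jax
import jax.numpy as jnp
from jax import config

config.update('jax_debug_nans', True)  # raise as soon as a NaN is produced


def loss(x):
  return jnp.sum(x ** 2)


with jax.disable_jit():  # run eagerly so the op producing a NaN is easy to locate
  g = jax.grad(loss)(jnp.ones(3))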
Example 3
def train_rl(
    output_dir,
    train_batch_size,
    eval_batch_size,
    env_name='Acrobot-v1',
    max_timestep=None,
    clip_rewards=False,
    rendered_env=False,
    resize=False,
    resize_dims=(105, 80),
    trainer_class=rl_trainers.PPO,
    n_epochs=10000,
    trajectory_dump_dir=None,
    num_actions=None,
    light_rl=False,
    light_rl_trainer=light_trainers.RLTrainer,
):
  """Train the RL agent.

  Args:
    output_dir: Output directory.
    train_batch_size: Number of parallel environments to use for training.
    eval_batch_size: Number of parallel environments to use for evaluation.
    env_name: Name of the environment.
    max_timestep: Int or None, the maximum number of timesteps in a trajectory.
      The environment is wrapped in a TimeLimit wrapper.
    clip_rewards: Whether to clip and discretize the rewards.
    rendered_env: Whether the environment has visual input. If so, a
      RenderedEnvProblem will be used.
    resize: Whether to resize the visual observations.
    resize_dims: Pair (height, width), dimensions to resize the visual
      observations to.
    trainer_class: RLTrainer class to use.
    n_epochs: Number of epochs to run the training for.
    trajectory_dump_dir: Directory to dump trajectories to.
    num_actions: None unless one wants to use the discretization wrapper. Then
      num_actions specifies the number of discrete actions.
    light_rl: whether to use the light RL setting (experimental).
    light_rl_trainer: which light RL trainer to use (experimental).
  """
  tf_np.set_allow_float64(FLAGS.tf_allow_float64)

  if light_rl:
    task = rl_task.RLTask()
    env_name = task.env_name

  if FLAGS.jax_debug_nans:
    config.update('jax_debug_nans', True)

  if FLAGS.use_tpu:
    config.update('jax_platform_name', 'tpu')
  else:
    config.update('jax_platform_name', '')

  if light_rl:
    trainer = light_rl_trainer(task=task, output_dir=output_dir)
    if FLAGS.jax_debug_nans or FLAGS.disable_jit:
      math.disable_jit()
      with jax.disable_jit():
        trainer.run(n_epochs, n_epochs_is_total_epochs=True)
        trainer.close()
    else:
      trainer.run(n_epochs, n_epochs_is_total_epochs=True)
      trainer.close()
    return

  # TODO(pkozakowski): Find a better way to determine this.
  train_env_kwargs = {}
  eval_env_kwargs = {}
  if 'OnlineTuneEnv' in env_name:
    envs_output_dir = FLAGS.envs_output_dir or os.path.join(output_dir, 'envs')
    train_env_output_dir = os.path.join(envs_output_dir, 'train')
    eval_env_output_dir = os.path.join(envs_output_dir, 'eval')
    train_env_kwargs = {'output_dir': train_env_output_dir}
    eval_env_kwargs = {'output_dir': eval_env_output_dir}

  parallelism = multiprocessing.cpu_count() if FLAGS.parallelize_envs else 1

  logging.info('Num discretized actions %s', num_actions)
  logging.info('Resize %d', resize)

  train_env = env_problem_utils.make_env(
      batch_size=train_batch_size,
      env_problem_name=env_name,
      rendered_env=rendered_env,
      resize=resize,
      resize_dims=resize_dims,
      max_timestep=max_timestep,
      clip_rewards=clip_rewards,
      parallelism=parallelism,
      use_tpu=FLAGS.use_tpu,
      num_actions=num_actions,
      **train_env_kwargs)
  assert train_env

  eval_env = env_problem_utils.make_env(
      batch_size=eval_batch_size,
      env_problem_name=env_name,
      rendered_env=rendered_env,
      resize=resize,
      resize_dims=resize_dims,
      max_timestep=max_timestep,
      clip_rewards=clip_rewards,
      parallelism=parallelism,
      use_tpu=FLAGS.use_tpu,
      num_actions=num_actions,
      **eval_env_kwargs)
  assert eval_env

  def run_training_loop():
    """Runs the training loop."""
    logging.info('Starting the training loop.')

    trainer = trainer_class(
        output_dir=output_dir,
        train_env=train_env,
        eval_env=eval_env,
        trajectory_dump_dir=trajectory_dump_dir,
        async_mode=FLAGS.async_mode,
    )
    trainer.training_loop(n_epochs=n_epochs)

  if FLAGS.jax_debug_nans or FLAGS.disable_jit:
    math.disable_jit()
    with jax.disable_jit():
      run_training_loop()
  else:
    run_training_loop()
Example 4
def tearDown(self):
  tf_np.set_allow_float64(self._old_is_allow_float64)
  super().tearDown()
Example 5
def setUp(self):
  super().setUp()
  test_utils.ensure_flag('test_tmpdir')
  self._old_is_allow_float64 = tf_np.is_allow_float64()
  tf_np.set_allow_float64(False)
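Examples 4 and 5 are two halves of the same test fixture: setUp saves the current float64 setting and forces it off, tearDown restores it. The sketch below combines that pattern into one self-contained test case; the import path for tf_np is an assumption, and the test body is illustrative only.
# A minimal sketch (assumed imports): the save/restore fixture from
# Examples 4 and 5 combined into one self-contained test case.
from absl.testing import absltest
from trax.tf_numpy import numpy as tf_np  # assumed path for the tf_np alias


class AllowFloat64Test(absltest.TestCase):

  def setUp(self):
    super().setUp()
    self._old_is_allow_float64 = tf_np.is_allow_float64()  # remember global state
    tf_np.set_allow_float64(False)  # force float32-only behavior for this test

  def tearDown(self):
    tf_np.set_allow_float64(self._old_is_allow_float64)  # restore global state
    super().tearDown()

  def test_float64_is_disallowed(self):
    self.assertFalse(tf_np.is_allow_float64())


if __name__ == '__main__':
  absltest.main()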
Example 6
def train_rl(
    output_dir,
    train_batch_size,
    eval_batch_size,
    env_name='Acrobot-v1',
    max_timestep=None,
    clip_rewards=False,
    rendered_env=False,
    resize=False,
    resize_dims=(105, 80),
    trainer_class=None,
    n_epochs=10000,
    trajectory_dump_dir=None,
    num_actions=None,
    light_rl=True,
    light_rl_trainer=light_trainers.PolicyGradient,
):
  """Train the RL agent.

  Args:
    output_dir: Output directory.
    train_batch_size: Number of parallel environments to use for training.
    eval_batch_size: Number of parallel environments to use for evaluation.
    env_name: Name of the environment.
    max_timestep: Int or None, the maximum number of timesteps in a trajectory.
      The environment is wrapped in a TimeLimit wrapper.
    clip_rewards: Whether to clip and discretize the rewards.
    rendered_env: Whether the environment has visual input. If so, a
      RenderedEnvProblem will be used.
    resize: Whether to resize the visual observations.
    resize_dims: Pair (height, width), dimensions to resize the visual
      observations to.
    trainer_class: RLTrainer class to use.
    n_epochs: Number of epochs to run the training for.
    trajectory_dump_dir: Directory to dump trajectories to.
    num_actions: None unless one wants to use the discretization wrapper. Then
      num_actions specifies the number of discrete actions.
    light_rl: deprecated, always True, left out for old gin configs.
    light_rl_trainer: which light RL trainer to use (experimental).
  """
  del light_rl
  tf_np.set_allow_float64(FLAGS.tf_allow_float64)
  task = rl_task.RLTask()
  env_name = task.env_name

  if FLAGS.jax_debug_nans:
    config.update('jax_debug_nans', True)

  if FLAGS.use_tpu:
    config.update('jax_platform_name', 'tpu')
  else:
    config.update('jax_platform_name', '')

  trainer = light_rl_trainer(task=task, output_dir=output_dir)

  def light_training_loop():
    """Run the trainer for n_epochs and call close on it."""
    try:
      logging.info('Starting RL training for %d epochs.', n_epochs)
      trainer.run(n_epochs, n_epochs_is_total_epochs=True)
      logging.info('Completed RL training for %d epochs.', n_epochs)
      trainer.close()
      logging.info('Trainer is now closed.')
    except Exception as e:
      raise e
    finally:
      logging.info(
          'Encountered an exception, still calling trainer.close()')
      trainer.close()
      logging.info('Trainer is now closed.')

  if FLAGS.jax_debug_nans or FLAGS.disable_jit:
    fastmath.disable_jit()
    with jax.disable_jit():
      light_training_loop()
  else:
    light_training_loop()
Example 7
def set_tf_allow_float64(b):
    tf_np.set_allow_float64(b)
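Example 7 is just a thin wrapper around the same call. Outside of tests, the usual way to use it safely is the same save/toggle/restore idiom seen above. Below is a minimal sketch, assuming the same tf_np import path as before; the asarray call is only a stand-in for any TF-NumPy computation.
# A minimal sketch (assumed import path): temporarily allow float64 around a
# block of TF-NumPy work, then restore whatever setting was active before.
from trax.tf_numpy import numpy as tf_np  # assumed path for the tf_np alias

old_value = tf_np.is_allow_float64()
tf_np.set_allow_float64(True)
try:
  x = tf_np.asarray([1.0, 2.0, 3.0])  # stand-in for float64-sensitive work
finally:
  tf_np.set_allow_float64(old_value)  # always restore the previous setting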