Example #1
def train_rl(output_dir,
             n_epochs=10000,
             light_rl=True,
             light_rl_trainer=light_trainers.PolicyGradient):
    """Train the RL agent.

  Args:
    output_dir: Output directory.
    n_epochs: Number of epochs to run the training for.
    light_rl: deprecated, always True; kept only for compatibility with old
      gin configs.
    light_rl_trainer: which light RL trainer to use (experimental).
  """
    del light_rl
    tf_np.set_allow_float64(FLAGS.tf_allow_float64)
    task = rl_task.RLTask()
    env_name = task.env_name

    if FLAGS.jax_debug_nans:
        config.update('jax_debug_nans', True)

    if FLAGS.use_tpu:
        config.update('jax_platform_name', 'tpu')
    else:
        config.update('jax_platform_name', '')

    trainer = light_rl_trainer(task=task, output_dir=output_dir)

    def light_training_loop():
        """Run the trainer for n_epochs and call close on it."""
        try:
            logging.info('Starting RL training for %d epochs.', n_epochs)
            trainer.run(n_epochs, n_epochs_is_total_epochs=True)
            logging.info('Completed RL training for %d epochs.', n_epochs)
        finally:
            # Always close the trainer, whether training completed or raised.
            trainer.close()
            logging.info('Trainer is now closed.')

    if FLAGS.jax_debug_nans or FLAGS.disable_jit:
        fastmath.disable_jit()
        with jax.disable_jit():
            light_training_loop()
    else:
        light_training_loop()
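
The `jax_debug_nans` / `disable_jit` branch at the end of the example is a standard JAX debugging pattern: with NaN checking enabled and JIT disabled, the first primitive that produces a NaN raises immediately instead of propagating silently. A minimal, self-contained sketch of that pattern (independent of the Trax code above):

# Minimal NaN-debugging sketch: enable NaN checks and run without JIT so the
# offending operation is reported eagerly.
import jax
import jax.numpy as jnp

jax.config.update('jax_debug_nans', True)

def log_loss(x):
    return jnp.log(x)  # NaN for non-positive inputs

with jax.disable_jit():
    try:
        log_loss(jnp.array(-1.0))
    except FloatingPointError as e:
        print('Caught NaN at the offending op:', e)
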
Example #2
def main(_):
    logging.set_verbosity(FLAGS.log_level)

    _tf_setup_from_flags()
    _gin_parse_configs()
    _jax_and_tf_configure_for_devices()

    if FLAGS.disable_jit:
        fastmath.disable_jit()

    output_dir = _output_dir_or_default()
    if FLAGS.use_tpu and fastmath.is_backend(Backend.TFNP):
        _train_using_tf(output_dir)
    else:
        trainer_lib.train(output_dir=output_dir)

    trainer_lib.log('Finished training.')
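
The helpers `_tf_setup_from_flags`, `_gin_parse_configs` and `_jax_and_tf_configure_for_devices` are private to the module and not shown on this page. As a rough illustration of the gin step only, a hypothetical helper (the flag handling, file name, and example binding below are assumptions, not the library's actual code) would parse config files and override strings with gin-config like this:

# Hypothetical sketch of a _gin_parse_configs-style helper; the file name and
# binding below are illustrative only.
import gin

def parse_gin_configs(config_files, gin_bindings):
    """Parse .gin files plus 'configurable.param = value' override strings."""
    gin.parse_config_files_and_bindings(config_files, gin_bindings)

# Usage (illustrative): parse_gin_configs(['rl_ppo.gin'], ['train_rl.n_epochs = 100'])
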
Example #3
def main(_):
    logging.set_verbosity(FLAGS.log_level)

    _tf_setup_from_flags()
    _gin_parse_configs()
    _jax_and_tf_configure_for_devices()

    # Create a JAX GPU cluster if using JAX and given a chief IP.
    if fastmath.is_backend(Backend.JAX) and FLAGS.gpu_cluster_chief_ip:
        _make_jax_gpu_cluster(FLAGS.gpu_cluster_host_id,
                              FLAGS.gpu_cluster_chief_ip,
                              FLAGS.gpu_cluster_n_hosts,
                              FLAGS.gpu_cluster_port)

    if FLAGS.disable_jit:
        fastmath.disable_jit()

    output_dir = _output_dir_or_default()
    if FLAGS.use_tpu and fastmath.is_backend(Backend.TFNP):
        _train_using_tf(output_dir)
    else:
        trainer_lib.train(output_dir=output_dir)

    trainer_lib.log('Finished training.')
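
`_make_jax_gpu_cluster` is another private helper that is not shown here. Its arguments (host id, chief IP, number of hosts, port) map onto what current JAX exposes as `jax.distributed.initialize`; the sketch below is an equivalent setup under that assumption, not a copy of the Trax implementation.

# Hedged sketch of multi-host GPU initialization with current JAX.
import jax

def make_gpu_cluster(host_id, chief_ip, n_hosts, port):
    # The chief listens on chief_ip:port; every host, including the chief,
    # calls initialize with its own process index.
    jax.distributed.initialize(
        coordinator_address=f'{chief_ip}:{port}',
        num_processes=n_hosts,
        process_id=host_id,
    )

# After initialization, jax.process_index() and jax.device_count() report
# cluster-wide values rather than single-host ones.
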
Example #4
def train_rl(
    output_dir,
    train_batch_size,
    eval_batch_size,
    env_name='Acrobot-v1',
    max_timestep=None,
    clip_rewards=False,
    rendered_env=False,
    resize=False,
    resize_dims=(105, 80),
    trainer_class=None,
    n_epochs=10000,
    trajectory_dump_dir=None,
    num_actions=None,
    light_rl=True,
    light_rl_trainer=light_trainers.PolicyGradient,
):
    """Train the RL agent.

  Args:
    output_dir: Output directory.
    train_batch_size: Number of parallel environments to use for training.
    eval_batch_size: Number of parallel environments to use for evaluation.
    env_name: Name of the environment.
    max_timestep: Int or None, the maximum number of timesteps in a trajectory.
      The environment is wrapped in a TimeLimit wrapper.
    clip_rewards: Whether to clip and discretize the rewards.
    rendered_env: Whether the environment has visual input. If so, a
      RenderedEnvProblem will be used.
    resize: Whether to resize the visual observations.
    resize_dims: Pair (height, width), dimensions to resize the visual
      observations to.
    trainer_class: RLTrainer class to use.
    n_epochs: Number of epochs to run the training for.
    trajectory_dump_dir: Directory to dump trajectories to.
    num_actions: None unless one wants to use the discretization wrapper. Then
      num_actions specifies the number of discrete actions.
    light_rl: whether to use the light RL setting (experimental).
    light_rl_trainer: which light RL trainer to use (experimental).
  """
    tf_np.set_allow_float64(FLAGS.tf_allow_float64)

    if light_rl:
        task = rl_task.RLTask()
        env_name = task.env_name
    else:
        # TODO(lukaszkaiser): remove the name light and all references.
        # It was kept for now to make sure all regression tests pass first,
        # so that if we need to revert we save some work.
        raise ValueError('Non-light RL is deprecated.')

    if FLAGS.jax_debug_nans:
        config.update('jax_debug_nans', True)

    if FLAGS.use_tpu:
        config.update('jax_platform_name', 'tpu')
    else:
        config.update('jax_platform_name', '')

    if light_rl:
        trainer = light_rl_trainer(task=task, output_dir=output_dir)

        def light_training_loop():
            """Run the trainer for n_epochs and call close on it."""
            try:
                logging.info('Starting RL training for %d epochs.', n_epochs)
                trainer.run(n_epochs, n_epochs_is_total_epochs=True)
                logging.info('Completed RL training for %d epochs.', n_epochs)
            finally:
                # Always close the trainer, whether training completed or raised.
                trainer.close()
                logging.info('Trainer is now closed.')

        if FLAGS.jax_debug_nans or FLAGS.disable_jit:
            fastmath.disable_jit()
            with jax.disable_jit():
                light_training_loop()
        else:
            light_training_loop()
        return

    # TODO(pkozakowski): Find a better way to determine this.
    train_env_kwargs = {}
    eval_env_kwargs = {}
    if 'OnlineTuneEnv' in env_name:
        envs_output_dir = FLAGS.envs_output_dir or os.path.join(
            output_dir, 'envs')
        train_env_output_dir = os.path.join(envs_output_dir, 'train')
        eval_env_output_dir = os.path.join(envs_output_dir, 'eval')
        train_env_kwargs = {'output_dir': train_env_output_dir}
        eval_env_kwargs = {'output_dir': eval_env_output_dir}

    parallelism = multiprocessing.cpu_count() if FLAGS.parallelize_envs else 1

    logging.info('Num discretized actions %s', num_actions)
    logging.info('Resize %d', resize)

    train_env = env_problem_utils.make_env(batch_size=train_batch_size,
                                           env_problem_name=env_name,
                                           rendered_env=rendered_env,
                                           resize=resize,
                                           resize_dims=resize_dims,
                                           max_timestep=max_timestep,
                                           clip_rewards=clip_rewards,
                                           parallelism=parallelism,
                                           use_tpu=FLAGS.use_tpu,
                                           num_actions=num_actions,
                                           **train_env_kwargs)
    assert train_env

    eval_env = env_problem_utils.make_env(batch_size=eval_batch_size,
                                          env_problem_name=env_name,
                                          rendered_env=rendered_env,
                                          resize=resize,
                                          resize_dims=resize_dims,
                                          max_timestep=max_timestep,
                                          clip_rewards=clip_rewards,
                                          parallelism=parallelism,
                                          use_tpu=FLAGS.use_tpu,
                                          num_actions=num_actions,
                                          **eval_env_kwargs)
    assert eval_env

    def run_training_loop():
        """Runs the training loop."""
        logging.info('Starting the training loop.')

        trainer = trainer_class(
            output_dir=output_dir,
            train_env=train_env,
            eval_env=eval_env,
            trajectory_dump_dir=trajectory_dump_dir,
            async_mode=FLAGS.async_mode,
        )
        trainer.training_loop(n_epochs=n_epochs)

    if FLAGS.jax_debug_nans or FLAGS.disable_jit:
        fastmath.disable_jit()
        with jax.disable_jit():
            run_training_loop()
    else:
        run_training_loop()
Example #5
def train_rl(
    output_dir,
    train_batch_size,
    eval_batch_size,
    env_name='Acrobot-v1',
    max_timestep=None,
    clip_rewards=False,
    rendered_env=False,
    resize=False,
    resize_dims=(105, 80),
    trainer_class=None,
    n_epochs=10000,
    trajectory_dump_dir=None,
    num_actions=None,
    light_rl=True,
    light_rl_trainer=light_trainers.PolicyGradient,
):
    """Train the RL agent.

  Args:
    output_dir: Output directory.
    train_batch_size: Number of parallel environments to use for training.
    eval_batch_size: Number of parallel environments to use for evaluation.
    env_name: Name of the environment.
    max_timestep: Int or None, the maximum number of timesteps in a trajectory.
      The environment is wrapped in a TimeLimit wrapper.
    clip_rewards: Whether to clip and discretize the rewards.
    rendered_env: Whether the environment has visual input. If so, a
      RenderedEnvProblem will be used.
    resize: Whether to resize the visual observations.
    resize_dims: Pair (height, width), dimensions to resize the visual
      observations to.
    trainer_class: RLTrainer class to use.
    n_epochs: Number of epochs to run the training for.
    trajectory_dump_dir: Directory to dump trajectories to.
    num_actions: None unless one wants to use the discretization wrapper. Then
      num_actions specifies the number of discrete actions.
    light_rl: deprecated, always True; kept only for compatibility with old
      gin configs.
    light_rl_trainer: which light RL trainer to use (experimental).
  """
    del light_rl
    tf_np.set_allow_float64(FLAGS.tf_allow_float64)
    task = rl_task.RLTask()
    env_name = task.env_name

    if FLAGS.jax_debug_nans:
        config.update('jax_debug_nans', True)

    if FLAGS.use_tpu:
        config.update('jax_platform_name', 'tpu')
    else:
        config.update('jax_platform_name', '')

    trainer = light_rl_trainer(task=task, output_dir=output_dir)

    def light_training_loop():
        """Run the trainer for n_epochs and call close on it."""
        try:
            logging.info('Starting RL training for %d epochs.', n_epochs)
            trainer.run(n_epochs, n_epochs_is_total_epochs=True)
            logging.info('Completed RL training for %d epochs.', n_epochs)
        finally:
            # Always close the trainer, whether training completed or raised.
            trainer.close()
            logging.info('Trainer is now closed.')

    if FLAGS.jax_debug_nans or FLAGS.disable_jit:
        fastmath.disable_jit()
        with jax.disable_jit():
            light_training_loop()
    else:
        light_training_loop()
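
For reference, the light-RL objects used in the examples can also be driven directly, without the flag machinery. A minimal sketch: the constructor and method calls mirror those made inside train_rl above, while the import paths and output directory are assumptions, and RLTask still expects its environment settings to come from gin.

# Direct-usage sketch; import paths are assumptions, the task/trainer calls
# mirror those inside train_rl.
from trax import rl as light_trainers
from trax.rl import task as rl_task

task = rl_task.RLTask()  # environment and episode settings come from gin
trainer = light_trainers.PolicyGradient(task=task, output_dir='/tmp/rl_output')
try:
    trainer.run(10, n_epochs_is_total_epochs=True)
finally:
    trainer.close()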