Example #1
def main(argv):
  del argv

  if FLAGS.enable_eager_execution:
    tf.enable_eager_execution()

  output_dir = FLAGS.output_dir

  # Initialize Gin.
  initialize_gin()

  output_dir = os.path.join(output_dir, str(FLAGS.replica))

  env_kwargs = {"output_dir": output_dir}

  env = env_problem_utils.make_env(
      batch_size=1,
      env_problem_name=FLAGS.env_problem_name,
      resize=FLAGS.resize,
      resized_height=FLAGS.resized_height,
      resized_width=FLAGS.resized_width,
      max_timestep=FLAGS.max_timestep,
      clip_rewards=FLAGS.clip_rewards,
      **env_kwargs)

  logging.info("Replica[%s] is ready to serve requests.", FLAGS.replica)
  server_utils.serve(output_dir, env, FLAGS.env_service_port)
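Example #1 shows only the main function; the flag definitions and program entry point sit outside the snippet. Below is a minimal sketch of that scaffolding, assuming the standard absl.app / absl.flags setup: the flag names come from the snippet itself, while the defaults and help strings are placeholders, and the env_problem_utils, server_utils, gin and TensorFlow imports are assumed to exist in the original module.

# Hypothetical scaffolding for Example #1 -- not part of the original snippet.
from absl import app, flags

flags.DEFINE_string('output_dir', '/tmp/env_server', 'Directory for environment output.')
flags.DEFINE_integer('replica', 0, 'Index of this server replica.')
flags.DEFINE_string('env_problem_name', 'Acrobot-v1', 'Name of the environment to serve.')
flags.DEFINE_boolean('resize', False, 'Whether to resize visual observations.')
flags.DEFINE_integer('resized_height', 105, 'Height to resize observations to.')
flags.DEFINE_integer('resized_width', 80, 'Width to resize observations to.')
flags.DEFINE_integer('max_timestep', None, 'Maximum number of timesteps per trajectory.')
flags.DEFINE_boolean('clip_rewards', False, 'Whether to clip and discretize rewards.')
flags.DEFINE_integer('env_service_port', 7777, 'Port on which to serve the environment.')
flags.DEFINE_boolean('enable_eager_execution', False, 'Whether to enable TF eager execution.')

FLAGS = flags.FLAGS

if __name__ == '__main__':
  app.run(main)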
Example #2
def main(argv):
    del argv
    output_dir = FLAGS.output_dir

    output_dir = os.path.join(output_dir, str(FLAGS.replica))

    env = env_problem_utils.make_env(batch_size=1,
                                     env_problem_name=FLAGS.env_problem_name,
                                     resize=FLAGS.resize,
                                     resized_height=FLAGS.resized_height,
                                     resized_width=FLAGS.resized_width,
                                     max_timestep=FLAGS.max_timestep,
                                     clip_rewards=FLAGS.clip_rewards)

    logging.info("Replica[%s] is ready to serve requests.", FLAGS.replica)
    server_utils.serve(output_dir, env, FLAGS.env_service_port)
Example #3
def main(argv):
    del argv

    if FLAGS.replicas == 0:
        env = client_env.ClientEnv(FLAGS.server_bns)
        pdb.set_trace()
        env.close()
        return

    # Replicated server.
    per_env_kwargs = [{
        "remote_env_address":
        os.path.join(FLAGS.server_bns, str(replica))
    } for replica in range(FLAGS.replicas)]
    env = env_problem_utils.make_env(batch_size=FLAGS.replicas,
                                     env_problem_name="ClientEnv-v0",
                                     resize=False,
                                     parallelism=FLAGS.replicas,
                                     per_env_kwargs=per_env_kwargs)

    pdb.set_trace()

    env.close()
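At the pdb.set_trace() breakpoints in Example #3, the batched client env can be driven interactively. A minimal sketch of such a rollout is shown below, assuming the env exposes a gym-style batched interface where reset() returns one observation per replica and step() returns (observations, rewards, dones, infos); the helper name rollout and the random-action policy are illustrative, not part of the original code.

import numpy as np

def rollout(env, n_steps=10):
    """Drives a batched env with random actions (assumes gym-like reset/step)."""
    observations = env.reset()  # One observation per replica in the batch.
    for _ in range(n_steps):
        # Sample one action per sub-environment.
        actions = np.stack(
            [env.action_space.sample() for _ in range(len(observations))])
        observations, rewards, dones, _ = env.step(actions)
        print('rewards:', rewards, 'dones:', dones)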
Example #4
def train_rl(
    output_dir,
    train_batch_size,
    eval_batch_size,
    env_name='Acrobot-v1',
    max_timestep=None,
    clip_rewards=False,
    rendered_env=False,
    resize=False,
    resize_dims=(105, 80),
    trainer_class=rl_trainers.PPO,
    n_epochs=10000,
    trajectory_dump_dir=None,
    num_actions=None,
    light_rl=False,
    light_rl_trainer=light_trainers.RLTrainer,
):
  """Train the RL agent.

  Args:
    output_dir: Output directory.
    train_batch_size: Number of parallel environments to use for training.
    eval_batch_size: Number of parallel environments to use for evaluation.
    env_name: Name of the environment.
    max_timestep: Int or None, the maximum number of timesteps in a trajectory.
      The environment is wrapped in a TimeLimit wrapper.
    clip_rewards: Whether to clip and discretize the rewards.
    rendered_env: Whether the environment has visual input. If so, a
      RenderedEnvProblem will be used.
    resize: Whether to resize the visual observations.
    resize_dims: Pair (height, width), dimensions to resize the visual
      observations to.
    trainer_class: RLTrainer class to use.
    n_epochs: Number of epochs to run the training for.
    trajectory_dump_dir: Directory to dump trajectories to.
    num_actions: None unless one wants to use the discretization wrapper. Then
      num_actions specifies the number of discrete actions.
    light_rl: Whether to use the light RL setting (experimental).
    light_rl_trainer: Which light RL trainer to use (experimental).
  """
  tf_np.set_allow_float64(FLAGS.tf_allow_float64)

  if light_rl:
    task = rl_task.RLTask()
    env_name = task.env_name


  if FLAGS.jax_debug_nans:
    config.update('jax_debug_nans', True)

  if FLAGS.use_tpu:
    config.update('jax_platform_name', 'tpu')
  else:
    config.update('jax_platform_name', '')


  if light_rl:
    trainer = light_rl_trainer(task=task, output_dir=output_dir)
    if FLAGS.jax_debug_nans or FLAGS.disable_jit:
      math.disable_jit()
      with jax.disable_jit():
        trainer.run(n_epochs, n_epochs_is_total_epochs=True)
        trainer.close()
    else:
      trainer.run(n_epochs, n_epochs_is_total_epochs=True)
      trainer.close()
    return

  # TODO(pkozakowski): Find a better way to determine this.
  train_env_kwargs = {}
  eval_env_kwargs = {}
  if 'OnlineTuneEnv' in env_name:
    envs_output_dir = FLAGS.envs_output_dir or os.path.join(output_dir, 'envs')
    train_env_output_dir = os.path.join(envs_output_dir, 'train')
    eval_env_output_dir = os.path.join(envs_output_dir, 'eval')
    train_env_kwargs = {'output_dir': train_env_output_dir}
    eval_env_kwargs = {'output_dir': eval_env_output_dir}

  parallelism = multiprocessing.cpu_count() if FLAGS.parallelize_envs else 1

  logging.info('Num discretized actions %s', num_actions)
  logging.info('Resize %d', resize)

  train_env = env_problem_utils.make_env(
      batch_size=train_batch_size,
      env_problem_name=env_name,
      rendered_env=rendered_env,
      resize=resize,
      resize_dims=resize_dims,
      max_timestep=max_timestep,
      clip_rewards=clip_rewards,
      parallelism=parallelism,
      use_tpu=FLAGS.use_tpu,
      num_actions=num_actions,
      **train_env_kwargs)
  assert train_env

  eval_env = env_problem_utils.make_env(
      batch_size=eval_batch_size,
      env_problem_name=env_name,
      rendered_env=rendered_env,
      resize=resize,
      resize_dims=resize_dims,
      max_timestep=max_timestep,
      clip_rewards=clip_rewards,
      parallelism=parallelism,
      use_tpu=FLAGS.use_tpu,
      num_actions=num_actions,
      **eval_env_kwargs)
  assert eval_env

  def run_training_loop():
    """Runs the training loop."""
    logging.info('Starting the training loop.')

    trainer = trainer_class(
        output_dir=output_dir,
        train_env=train_env,
        eval_env=eval_env,
        trajectory_dump_dir=trajectory_dump_dir,
        async_mode=FLAGS.async_mode,
    )
    trainer.training_loop(n_epochs=n_epochs)

  if FLAGS.jax_debug_nans or FLAGS.disable_jit:
    math.disable_jit()
    with jax.disable_jit():
      run_training_loop()
  else:
    run_training_loop()
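Example #4 reads several values from FLAGS, so in its original setting it is wired to a flag-parsing entry point rather than called directly. A direct call could look roughly like the sketch below; the argument values are placeholders, and the flags it references (such as use_tpu, async_mode, jax_debug_nans and tf_allow_float64) still need to be defined and parsed beforehand.

# Illustrative invocation only; all argument values are placeholders.
train_rl(
    output_dir='/tmp/rl_training',
    train_batch_size=8,
    eval_batch_size=4,
    env_name='CartPole-v0',
    max_timestep=200,
    n_epochs=100,
)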
Example #5
    def __init__(
        self,
        output_dir,
        env_name='PongNoFrameskip-v4',
        env_kwargs=None,
        train_batch_size=16,
        eval_batch_size=16,
        trainer_class=ppo_trainer.PPO,
        action_multipliers=None,
        observation_metrics=(
            ('eval', 'eval/raw_reward_mean/temperature_1.0'),
            ('eval', 'eval/raw_reward_std/temperature_1.0'),
        ),
        include_controls_in_observation=False,
        reward_metric=('eval', 'eval/raw_reward_mean/temperature_1.0'),
        train_epochs=100,
        env_steps=100,
        # This is a tuple instead of a dict because the controls are
        # ordered in the action space.
        control_configs=(
            # (name, start, (low, high), flip)
            ('learning_rate', 1e-3, (1e-9, 10.0), False), ),
        observation_range=(0.0, 10.0),
        # Don't save checkpoints by default, as they tend to use a lot of
        # space.
        should_save_checkpoints=False,
        # Same here.
        should_write_summaries=False,
    ):
        if action_multipliers is None:
            action_multipliers = self.DEFAULT_ACTION_MULTIPLIERS
        if env_kwargs is None:
            env_kwargs = {}
        (train_env, eval_env) = tuple(
            env_problem_utils.make_env(  # pylint: disable=g-complex-comprehension
                env_problem_name=env_name,
                batch_size=batch_size,
                **env_kwargs)
            for batch_size in (train_batch_size, eval_batch_size))
        # Initialize Trainer in OnlineTuneRLEnv lazily to prevent long startup in
        # the async setup, where we just use the environments as containers for
        # trajectories.
        self._trainer_fn = functools.partial(
            trainer_class,
            train_env=train_env,
            eval_env=eval_env,
            controller=(lambda history: lambda step: self._current_controls),
            should_save_checkpoints=should_save_checkpoints,
            should_write_summaries=should_write_summaries,
        )
        self._trainer = None
        self._action_multipliers = action_multipliers
        self._observation_metrics = observation_metrics
        self._include_controls_in_observation = include_controls_in_observation
        self._reward_metric = reward_metric
        self._train_epochs = train_epochs
        self._env_steps = env_steps
        self._control_configs = control_configs
        self._observation_range = observation_range

        self._output_dir = output_dir
        gfile.makedirs(self._output_dir)
        # Actions are indices in self._action_multipliers.
        self.action_space = gym.spaces.MultiDiscrete(
            [len(self._action_multipliers)] * len(self._control_configs))
        # Observation is a vector with the values of the metrics specified in
        # observation_metrics plus optionally the current controls.
        observation_dim = (len(self._observation_metrics) +
                           int(self._include_controls_in_observation) *
                           len(self._control_configs))

        (obs_low, obs_high) = observation_range
        self.observation_space = gym.spaces.Box(
            # Observations are clipped to this range.
            low=obs_low,
            high=obs_high,
            shape=(observation_dim, ),
        )
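With the defaults in Example #5 (two observation metrics, one control config, controls not included in the observation), the spaces come out as a MultiDiscrete action space with one dimension per control and a 2-dimensional Box observation. The snippet below merely restates that arithmetic and is not part of the original class.

# Restating the space arithmetic from Example #5 with its default arguments.
n_metrics, n_controls, include_controls = 2, 1, False
observation_dim = n_metrics + int(include_controls) * n_controls  # == 2
# action_space: MultiDiscrete([len(action_multipliers)] * n_controls)
# observation_space: Box(low=0.0, high=10.0, shape=(observation_dim,))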
Example #6
def train_rl(
    output_dir,
    train_batch_size,
    eval_batch_size,
    env_name='ClientEnv-v0',
    max_timestep=None,
    clip_rewards=False,
    rendered_env=False,
    resize_dims=(105, 80),
    trainer_class=rl_trainers.PPO,
    n_epochs=10000,
    trajectory_dump_dir=None,
):
  """Train the RL agent.

  Args:
    output_dir: Output directory.
    train_batch_size: Number of parallel environments to use for training.
    eval_batch_size: Number of parallel environments to use for evaluation.
    env_name: Name of the environment.
    max_timestep: Int or None, the maximum number of timesteps in a trajectory.
      The environment is wrapped in a TimeLimit wrapper.
    clip_rewards: Whether to clip and discretize the rewards.
    rendered_env: Whether the environment has visual input. If so, a
      RenderedEnvProblem will be used.
    resize_dims: Pair (height, width), dimensions to resize the visual
      observations to.
    trainer_class: RLTrainer class to use.
    n_epochs: Number of epochs to run the training for.
    trajectory_dump_dir: Directory to dump trajectories to.
  """

  if FLAGS.jax_debug_nans:
    config.update('jax_debug_nans', True)

  if FLAGS.use_tpu:
    config.update('jax_platform_name', 'tpu')
  else:
    config.update('jax_platform_name', 'gpu')


  # TODO(pkozakowski): Find a better way to determine this.
  train_env_kwargs = {}
  eval_env_kwargs = {}
  if 'OnlineTuneEnv' in env_name:
    envs_output_dir = FLAGS.envs_output_dir or os.path.join(output_dir, 'envs')
    train_env_output_dir = os.path.join(envs_output_dir, 'train')
    eval_env_output_dir = os.path.join(envs_output_dir, 'eval')
    train_env_kwargs = {'output_dir': train_env_output_dir}
    eval_env_kwargs = {'output_dir': eval_env_output_dir}

  if 'ClientEnv' in env_name:
    train_env_kwargs['per_env_kwargs'] = [{
        'remote_env_address': os.path.join(FLAGS.train_server_bns, str(replica))
    } for replica in range(train_batch_size)]

    eval_env_kwargs['per_env_kwargs'] = [{
        'remote_env_address': os.path.join(FLAGS.eval_server_bns, str(replica))
    } for replica in range(eval_batch_size)]
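    # For illustration only (hypothetical values, not from the original code):
    # with FLAGS.eval_server_bns == '/bns/cell/job/eval_env_server' and
    # eval_batch_size == 2, os.path.join above yields the per-replica addresses
    # '/bns/cell/job/eval_env_server/0' and '/bns/cell/job/eval_env_server/1'.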

  # TODO(afrozm): Should we leave out some cores?
  parallelism = multiprocessing.cpu_count() if FLAGS.parallelize_envs else 1

  train_env = env_problem_utils.make_env(
      batch_size=train_batch_size,
      env_problem_name=env_name,
      resize=rendered_env,
      resize_dims=resize_dims,
      max_timestep=max_timestep,
      clip_rewards=clip_rewards,
      parallelism=parallelism,
      use_tpu=FLAGS.use_tpu,
      **train_env_kwargs)
  assert train_env

  eval_env = env_problem_utils.make_env(
      batch_size=eval_batch_size,
      env_problem_name=env_name,
      resize=rendered_env,
      resize_dims=resize_dims,
      max_timestep=max_timestep,
      clip_rewards=clip_rewards,
      parallelism=parallelism,
      use_tpu=FLAGS.use_tpu,
      **eval_env_kwargs)
  assert eval_env

  def run_training_loop():
    """Runs the training loop."""
    logging.info('Starting the training loop.')

    trainer = trainer_class(
        output_dir=output_dir,
        train_env=train_env,
        eval_env=eval_env,
        trajectory_dump_dir=trajectory_dump_dir,
        async_mode=FLAGS.async_mode,
    )
    trainer.training_loop(n_epochs=n_epochs)

  if FLAGS.jax_debug_nans or FLAGS.disable_jit:
    with jax.disable_jit():
      run_training_loop()
  else:
    run_training_loop()
Example #7
def create_envs_and_collect_trajectories(
        output_dir,
        env_name="OnlineTuneEnv-v0",
        max_timestep=None,
        clip_rewards=False,
        rendered_env=False,
        resize_dims=(105, 80),
):
    """Creates the envs and continuously collects trajectories."""

    train_batch_size = 1
    eval_batch_size = 1

    # TODO(pkozakowski): Find a better way to determine this.
    train_env_kwargs = {}
    eval_env_kwargs = {}
    if "OnlineTuneEnv" in env_name:
        # TODO(pkozakowski): Separate env output dirs by train/eval and epoch.
        train_env_kwargs = {
            "output_dir": os.path.join(output_dir, "envs/train")
        }
        eval_env_kwargs = {"output_dir": os.path.join(output_dir, "envs/eval")}

    if "ClientEnv" in env_name:
        train_env_kwargs["per_env_kwargs"] = [{
            "remote_env_address":
            os.path.join(FLAGS.train_server_bns, str(replica))
        } for replica in range(train_batch_size)]

        eval_env_kwargs["per_env_kwargs"] = [{
            "remote_env_address":
            os.path.join(FLAGS.eval_server_bns, str(replica))
        } for replica in range(eval_batch_size)]

    parallelism = multiprocessing.cpu_count() if FLAGS.parallelize_envs else 1
    train_parallelism = min(train_batch_size, parallelism)
    eval_parallelism = min(eval_batch_size, parallelism)

    train_env = env_problem_utils.make_env(batch_size=train_batch_size,
                                           env_problem_name=env_name,
                                           resize=rendered_env,
                                           resize_dims=resize_dims,
                                           max_timestep=max_timestep,
                                           clip_rewards=clip_rewards,
                                           parallelism=train_parallelism,
                                           use_tpu=FLAGS.use_tpu,
                                           **train_env_kwargs)
    assert train_env

    eval_env = env_problem_utils.make_env(batch_size=eval_batch_size,
                                          env_problem_name=env_name,
                                          resize=rendered_env,
                                          resize_dims=resize_dims,
                                          max_timestep=max_timestep,
                                          clip_rewards=clip_rewards,
                                          parallelism=eval_parallelism,
                                          use_tpu=FLAGS.use_tpu,
                                          **eval_env_kwargs)
    assert eval_env

    def run_collect_loop():
        async_lib.continuously_collect_trajectories(
            output_dir,
            train_env,
            eval_env,
            trajectory_dump_dir=None,
            env_id=FLAGS.replica,
            try_abort=FLAGS.try_abort,
            max_trajectories_to_collect=(
                None if FLAGS.max_trajectories_to_collect < 0 else
                FLAGS.max_trajectories_to_collect))

    if FLAGS.jax_debug_nans or FLAGS.disable_jit:
        with jax.disable_jit():
            run_collect_loop()
    else:
        run_collect_loop()
Example #8
def create_envs_and_collect_trajectories(
        output_dir,
        env_name='OnlineTuneEnv-v0',
        max_timestep=None,
        clip_rewards=False,
        rendered_env=False,
        resize_dims=(105, 80),
):
    """Creates the envs and continuously collects trajectories."""

    train_batch_size = 1
    eval_batch_size = 1

    # TODO(pkozakowski): Find a better way to determine this.
    train_env_kwargs = {}
    eval_env_kwargs = {}
    if 'OnlineTuneEnv' in env_name:
        envs_output_dir = FLAGS.envs_output_dir or os.path.join(
            output_dir, 'envs')
        train_env_output_dir = os.path.join(envs_output_dir, 'train')
        eval_env_output_dir = os.path.join(envs_output_dir, 'eval')
        train_env_kwargs = {'output_dir': train_env_output_dir}
        eval_env_kwargs = {'output_dir': eval_env_output_dir}

    parallelism = multiprocessing.cpu_count() if FLAGS.parallelize_envs else 1
    train_parallelism = min(train_batch_size, parallelism)
    eval_parallelism = min(eval_batch_size, parallelism)

    train_env = env_problem_utils.make_env(batch_size=train_batch_size,
                                           env_problem_name=env_name,
                                           resize=rendered_env,
                                           resize_dims=resize_dims,
                                           max_timestep=max_timestep,
                                           clip_rewards=clip_rewards,
                                           parallelism=train_parallelism,
                                           use_tpu=FLAGS.use_tpu,
                                           **train_env_kwargs)
    assert train_env

    eval_env = env_problem_utils.make_env(batch_size=eval_batch_size,
                                          env_problem_name=env_name,
                                          resize=rendered_env,
                                          resize_dims=resize_dims,
                                          max_timestep=max_timestep,
                                          clip_rewards=clip_rewards,
                                          parallelism=eval_parallelism,
                                          use_tpu=FLAGS.use_tpu,
                                          **eval_env_kwargs)
    assert eval_env

    def run_collect_loop():
        async_lib.continuously_collect_trajectories(
            output_dir,
            train_env,
            eval_env,
            trajectory_dump_dir=None,
            env_id=FLAGS.replica,
            try_abort=FLAGS.try_abort,
            max_trajectories_to_collect=(
                None if FLAGS.max_trajectories_to_collect < 0 else
                FLAGS.max_trajectories_to_collect))

    if FLAGS.jax_debug_nans or FLAGS.disable_jit:
        with jax.disable_jit():
            run_collect_loop()
    else:
        run_collect_loop()