def main(argv):
  del argv

  if FLAGS.enable_eager_execution:
    tf.enable_eager_execution()

  output_dir = FLAGS.output_dir

  # Initialize Gin.
  initialize_gin()

  output_dir = os.path.join(output_dir, str(FLAGS.replica))

  env_kwargs = {"output_dir": output_dir}
  env = env_problem_utils.make_env(
      batch_size=1,
      env_problem_name=FLAGS.env_problem_name,
      resize=FLAGS.resize,
      resized_height=FLAGS.resized_height,
      resized_width=FLAGS.resized_width,
      max_timestep=FLAGS.max_timestep,
      clip_rewards=FLAGS.clip_rewards,
      **env_kwargs)

  logging.info("Replica[%s] is ready to serve requests.", FLAGS.replica)

  server_utils.serve(output_dir, env, FLAGS.env_service_port)
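# A minimal launcher sketch for the server entry point above, assuming the
# standard absl app/flags setup. The flag names are taken from their usage in
# main(); the definitions and defaults here are illustrative, not the
# originals.
from absl import app
from absl import flags

flags.DEFINE_integer("replica", 0, "Index of this environment server replica.")
flags.DEFINE_string("env_problem_name", None, "Name of the EnvProblem to serve.")
flags.DEFINE_integer("env_service_port", 7777, "Port to serve the environment on.")

if __name__ == "__main__":
  app.run(main)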
def main(argv):
  del argv

  output_dir = FLAGS.output_dir
  output_dir = os.path.join(output_dir, str(FLAGS.replica))

  env = env_problem_utils.make_env(
      batch_size=1,
      env_problem_name=FLAGS.env_problem_name,
      resize=FLAGS.resize,
      resized_height=FLAGS.resized_height,
      resized_width=FLAGS.resized_width,
      max_timestep=FLAGS.max_timestep,
      clip_rewards=FLAGS.clip_rewards)

  logging.info("Replica[%s] is ready to serve requests.", FLAGS.replica)

  server_utils.serve(output_dir, env, FLAGS.env_service_port)
def main(argv):
  del argv

  if FLAGS.replicas == 0:
    # Single (non-replicated) server.
    env = client_env.ClientEnv(FLAGS.server_bns)
    pdb.set_trace()
    env.close()
    return

  # Replicated server.
  per_env_kwargs = [{
      "remote_env_address": os.path.join(FLAGS.server_bns, str(replica))
  } for replica in range(FLAGS.replicas)]
  env = env_problem_utils.make_env(
      batch_size=FLAGS.replicas,
      env_problem_name="ClientEnv-v0",
      resize=False,
      parallelism=FLAGS.replicas,
      per_env_kwargs=per_env_kwargs)
  pdb.set_trace()
  env.close()
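# The pdb.set_trace() calls above drop into an interactive prompt with `env`
# in scope. A hedged sketch of driving the env by hand from that prompt,
# assuming the batched EnvProblem interface (reset() returns observations;
# step(actions) returns (observations, rewards, dones, infos)):
#
#   (Pdb) obs = env.reset()
#   (Pdb) actions = [env.action_space.sample() for _ in range(FLAGS.replicas)]
#   (Pdb) obs, rewards, dones, infos = env.step(actions)
#   (Pdb) env.close()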
def train_rl(
    output_dir,
    train_batch_size,
    eval_batch_size,
    env_name='Acrobot-v1',
    max_timestep=None,
    clip_rewards=False,
    rendered_env=False,
    resize=False,
    resize_dims=(105, 80),
    trainer_class=rl_trainers.PPO,
    n_epochs=10000,
    trajectory_dump_dir=None,
    num_actions=None,
    light_rl=False,
    light_rl_trainer=light_trainers.RLTrainer,
):
  """Train the RL agent.

  Args:
    output_dir: Output directory.
    train_batch_size: Number of parallel environments to use for training.
    eval_batch_size: Number of parallel environments to use for evaluation.
    env_name: Name of the environment.
    max_timestep: Int or None, the maximum number of timesteps in a
      trajectory. The environment is wrapped in a TimeLimit wrapper.
    clip_rewards: Whether to clip and discretize the rewards.
    rendered_env: Whether the environment has visual input. If so, a
      RenderedEnvProblem will be used.
    resize: Whether to resize the visual observations.
    resize_dims: Pair (height, width), dimensions to resize the visual
      observations to.
    trainer_class: RLTrainer class to use.
    n_epochs: Number of epochs to run the training for.
    trajectory_dump_dir: Directory to dump trajectories to.
    num_actions: None unless one wants to use the discretization wrapper.
      Then num_actions specifies the number of discrete actions.
    light_rl: Whether to use the light RL setting (experimental).
    light_rl_trainer: Which light RL trainer to use (experimental).
  """
  tf_np.set_allow_float64(FLAGS.tf_allow_float64)

  if light_rl:
    task = rl_task.RLTask()
    env_name = task.env_name

  if FLAGS.jax_debug_nans:
    config.update('jax_debug_nans', True)

  if FLAGS.use_tpu:
    config.update('jax_platform_name', 'tpu')
  else:
    config.update('jax_platform_name', '')

  if light_rl:
    trainer = light_rl_trainer(task=task, output_dir=output_dir)
    if FLAGS.jax_debug_nans or FLAGS.disable_jit:
      math.disable_jit()
      with jax.disable_jit():
        trainer.run(n_epochs, n_epochs_is_total_epochs=True)
        trainer.close()
    else:
      trainer.run(n_epochs, n_epochs_is_total_epochs=True)
      trainer.close()
    return

  # TODO(pkozakowski): Find a better way to determine this.
  train_env_kwargs = {}
  eval_env_kwargs = {}
  if 'OnlineTuneEnv' in env_name:
    envs_output_dir = FLAGS.envs_output_dir or os.path.join(output_dir, 'envs')
    train_env_output_dir = os.path.join(envs_output_dir, 'train')
    eval_env_output_dir = os.path.join(envs_output_dir, 'eval')
    train_env_kwargs = {'output_dir': train_env_output_dir}
    eval_env_kwargs = {'output_dir': eval_env_output_dir}

  parallelism = multiprocessing.cpu_count() if FLAGS.parallelize_envs else 1

  logging.info('Num discretized actions %s', num_actions)
  logging.info('Resize %d', resize)

  train_env = env_problem_utils.make_env(
      batch_size=train_batch_size,
      env_problem_name=env_name,
      rendered_env=rendered_env,
      resize=resize,
      resize_dims=resize_dims,
      max_timestep=max_timestep,
      clip_rewards=clip_rewards,
      parallelism=parallelism,
      use_tpu=FLAGS.use_tpu,
      num_actions=num_actions,
      **train_env_kwargs)
  assert train_env

  eval_env = env_problem_utils.make_env(
      batch_size=eval_batch_size,
      env_problem_name=env_name,
      rendered_env=rendered_env,
      resize=resize,
      resize_dims=resize_dims,
      max_timestep=max_timestep,
      clip_rewards=clip_rewards,
      parallelism=parallelism,
      use_tpu=FLAGS.use_tpu,
      num_actions=num_actions,
      **eval_env_kwargs)
  assert eval_env

  def run_training_loop():
    """Runs the training loop."""
    logging.info('Starting the training loop.')

    trainer = trainer_class(
        output_dir=output_dir,
        train_env=train_env,
        eval_env=eval_env,
        trajectory_dump_dir=trajectory_dump_dir,
        async_mode=FLAGS.async_mode,
    )
    trainer.training_loop(n_epochs=n_epochs)

  if FLAGS.jax_debug_nans or FLAGS.disable_jit:
    math.disable_jit()
    with jax.disable_jit():
      run_training_loop()
  else:
    run_training_loop()
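# A hedged example of calling train_rl() directly; the path and batch sizes
# are illustrative, and in the real binary these arguments are usually bound
# via Gin/flags rather than passed by hand:
#
#   train_rl(
#       output_dir='/tmp/rl_training',
#       train_batch_size=32,
#       eval_batch_size=4,
#       env_name='Acrobot-v1',
#       n_epochs=100,
#   )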
  def __init__(
      self,
      output_dir,
      env_name='PongNoFrameskip-v4',
      env_kwargs=None,
      train_batch_size=16,
      eval_batch_size=16,
      trainer_class=ppo_trainer.PPO,
      action_multipliers=None,
      observation_metrics=(
          ('eval', 'eval/raw_reward_mean/temperature_1.0'),
          ('eval', 'eval/raw_reward_std/temperature_1.0'),
      ),
      include_controls_in_observation=False,
      reward_metric=('eval', 'eval/raw_reward_mean/temperature_1.0'),
      train_epochs=100,
      env_steps=100,
      # This is a tuple instead of a dict because the controls are ordered in
      # the action space.
      control_configs=(
          # (name, start, (low, high), flip)
          ('learning_rate', 1e-3, (1e-9, 10.0), False),
      ),
      observation_range=(0.0, 10.0),
      # Don't save checkpoints by default, as they tend to use a lot of
      # space.
      should_save_checkpoints=False,
      # Same here.
      should_write_summaries=False,
  ):
    if action_multipliers is None:
      action_multipliers = self.DEFAULT_ACTION_MULTIPLIERS
    if env_kwargs is None:
      env_kwargs = {}
    (train_env, eval_env) = tuple(
        env_problem_utils.make_env(  # pylint: disable=g-complex-comprehension
            env_problem_name=env_name,
            batch_size=batch_size,
            **env_kwargs)
        for batch_size in (train_batch_size, eval_batch_size))
    # Initialize Trainer in OnlineTuneRLEnv lazily to prevent long startup in
    # the async setup, where we just use the environments as containers for
    # trajectories.
    self._trainer_fn = functools.partial(
        trainer_class,
        train_env=train_env,
        eval_env=eval_env,
        controller=(lambda history: lambda step: self._current_controls),
        should_save_checkpoints=should_save_checkpoints,
        should_write_summaries=should_write_summaries,
    )
    self._trainer = None
    self._action_multipliers = action_multipliers
    self._observation_metrics = observation_metrics
    self._include_controls_in_observation = include_controls_in_observation
    self._reward_metric = reward_metric
    self._train_epochs = train_epochs
    self._env_steps = env_steps
    self._control_configs = control_configs
    self._observation_range = observation_range

    self._output_dir = output_dir
    gfile.makedirs(self._output_dir)

    # Actions are indices in self._action_multipliers.
    self.action_space = gym.spaces.MultiDiscrete(
        [len(self._action_multipliers)] * len(self._control_configs))
    # Observation is a vector with the values of the metrics specified in
    # observation_metrics plus optionally the current controls.
    observation_dim = (
        len(self._observation_metrics) +
        int(self._include_controls_in_observation) * len(self._control_configs))
    (obs_low, obs_high) = observation_range
    self.observation_space = gym.spaces.Box(
        # Observations are clipped to this range.
        low=obs_low,
        high=obs_high,
        shape=(observation_dim,),
    )
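# A hedged construction example for this class (OnlineTuneRLEnv, per the
# comment in __init__ above); all values are illustrative:
#
#   env = OnlineTuneRLEnv(
#       output_dir='/tmp/online_tune',
#       env_name='PongNoFrameskip-v4',
#       train_batch_size=16,
#       eval_batch_size=16,
#   )
#   # One action index per control, each indexing into action_multipliers.
#   assert env.action_space.nvec.tolist() == (
#       [len(env._action_multipliers)] * len(env._control_configs))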
def train_rl(
    output_dir,
    train_batch_size,
    eval_batch_size,
    env_name='ClientEnv-v0',
    max_timestep=None,
    clip_rewards=False,
    rendered_env=False,
    resize_dims=(105, 80),
    trainer_class=rl_trainers.PPO,
    n_epochs=10000,
    trajectory_dump_dir=None,
):
  """Train the RL agent.

  Args:
    output_dir: Output directory.
    train_batch_size: Number of parallel environments to use for training.
    eval_batch_size: Number of parallel environments to use for evaluation.
    env_name: Name of the environment.
    max_timestep: Int or None, the maximum number of timesteps in a
      trajectory. The environment is wrapped in a TimeLimit wrapper.
    clip_rewards: Whether to clip and discretize the rewards.
    rendered_env: Whether the environment has visual input. If so, a
      RenderedEnvProblem will be used.
    resize_dims: Pair (height, width), dimensions to resize the visual
      observations to.
    trainer_class: RLTrainer class to use.
    n_epochs: Number of epochs to run the training for.
    trajectory_dump_dir: Directory to dump trajectories to.
  """
  if FLAGS.jax_debug_nans:
    config.update('jax_debug_nans', True)

  if FLAGS.use_tpu:
    config.update('jax_platform_name', 'tpu')
  else:
    config.update('jax_platform_name', 'gpu')

  # TODO(pkozakowski): Find a better way to determine this.
  train_env_kwargs = {}
  eval_env_kwargs = {}
  if 'OnlineTuneEnv' in env_name:
    envs_output_dir = FLAGS.envs_output_dir or os.path.join(output_dir, 'envs')
    train_env_output_dir = os.path.join(envs_output_dir, 'train')
    eval_env_output_dir = os.path.join(envs_output_dir, 'eval')
    train_env_kwargs = {'output_dir': train_env_output_dir}
    eval_env_kwargs = {'output_dir': eval_env_output_dir}

  if 'ClientEnv' in env_name:
    train_env_kwargs['per_env_kwargs'] = [{
        'remote_env_address': os.path.join(FLAGS.train_server_bns, str(replica))
    } for replica in range(train_batch_size)]
    eval_env_kwargs['per_env_kwargs'] = [{
        'remote_env_address': os.path.join(FLAGS.eval_server_bns, str(replica))
    } for replica in range(eval_batch_size)]

  # TODO(afrozm): Should we leave out some cores?
  parallelism = multiprocessing.cpu_count() if FLAGS.parallelize_envs else 1

  train_env = env_problem_utils.make_env(
      batch_size=train_batch_size,
      env_problem_name=env_name,
      resize=rendered_env,
      resize_dims=resize_dims,
      max_timestep=max_timestep,
      clip_rewards=clip_rewards,
      parallelism=parallelism,
      use_tpu=FLAGS.use_tpu,
      **train_env_kwargs)
  assert train_env

  eval_env = env_problem_utils.make_env(
      batch_size=eval_batch_size,
      env_problem_name=env_name,
      resize=rendered_env,
      resize_dims=resize_dims,
      max_timestep=max_timestep,
      clip_rewards=clip_rewards,
      parallelism=parallelism,
      use_tpu=FLAGS.use_tpu,
      **eval_env_kwargs)
  assert eval_env

  def run_training_loop():
    """Runs the training loop."""
    logging.info('Starting the training loop.')

    trainer = trainer_class(
        output_dir=output_dir,
        train_env=train_env,
        eval_env=eval_env,
        trajectory_dump_dir=trajectory_dump_dir,
        async_mode=FLAGS.async_mode,
    )
    trainer.training_loop(n_epochs=n_epochs)

  if FLAGS.jax_debug_nans or FLAGS.disable_jit:
    with jax.disable_jit():
      run_training_loop()
  else:
    run_training_loop()
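# The per-replica address construction above simply appends the replica index
# as an extra path component to the BNS address. An illustrative expansion
# (the BNS prefix is made up):
#
#   >>> os.path.join('/bns/cell/job/env_server', str(3))
#   '/bns/cell/job/env_server/3'
#
# Each ClientEnv in the batch then connects to the matching server replica
# started by the server main() above.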
def create_envs_and_collect_trajectories(
    output_dir,
    env_name="OnlineTuneEnv-v0",
    max_timestep=None,
    clip_rewards=False,
    rendered_env=False,
    resize_dims=(105, 80),
):
  """Creates the envs and continuously collects trajectories."""
  train_batch_size = 1
  eval_batch_size = 1

  # TODO(pkozakowski): Find a better way to determine this.
  train_env_kwargs = {}
  eval_env_kwargs = {}
  if "OnlineTuneEnv" in env_name:
    # TODO(pkozakowski): Separate env output dirs by train/eval and epoch.
    train_env_kwargs = {"output_dir": os.path.join(output_dir, "envs/train")}
    eval_env_kwargs = {"output_dir": os.path.join(output_dir, "envs/eval")}

  if "ClientEnv" in env_name:
    train_env_kwargs["per_env_kwargs"] = [{
        "remote_env_address": os.path.join(FLAGS.train_server_bns, str(replica))
    } for replica in range(train_batch_size)]
    eval_env_kwargs["per_env_kwargs"] = [{
        "remote_env_address": os.path.join(FLAGS.eval_server_bns, str(replica))
    } for replica in range(eval_batch_size)]

  parallelism = multiprocessing.cpu_count() if FLAGS.parallelize_envs else 1
  train_parallelism = min(train_batch_size, parallelism)
  eval_parallelism = min(eval_batch_size, parallelism)

  train_env = env_problem_utils.make_env(
      batch_size=train_batch_size,
      env_problem_name=env_name,
      resize=rendered_env,
      resize_dims=resize_dims,
      max_timestep=max_timestep,
      clip_rewards=clip_rewards,
      parallelism=train_parallelism,
      use_tpu=FLAGS.use_tpu,
      **train_env_kwargs)
  assert train_env

  eval_env = env_problem_utils.make_env(
      batch_size=eval_batch_size,
      env_problem_name=env_name,
      resize=rendered_env,
      resize_dims=resize_dims,
      max_timestep=max_timestep,
      clip_rewards=clip_rewards,
      parallelism=eval_parallelism,
      use_tpu=FLAGS.use_tpu,
      **eval_env_kwargs)
  assert eval_env

  def run_collect_loop():
    async_lib.continuously_collect_trajectories(
        output_dir,
        train_env,
        eval_env,
        trajectory_dump_dir=None,
        env_id=FLAGS.replica,
        try_abort=FLAGS.try_abort,
        max_trajectories_to_collect=(
            None if FLAGS.max_trajectories_to_collect < 0
            else FLAGS.max_trajectories_to_collect))

  if FLAGS.jax_debug_nans or FLAGS.disable_jit:
    with jax.disable_jit():
      run_collect_loop()
  else:
    run_collect_loop()
def create_envs_and_collect_trajectories(
    output_dir,
    env_name='OnlineTuneEnv-v0',
    max_timestep=None,
    clip_rewards=False,
    rendered_env=False,
    resize_dims=(105, 80),
):
  """Creates the envs and continuously collects trajectories."""
  train_batch_size = 1
  eval_batch_size = 1

  # TODO(pkozakowski): Find a better way to determine this.
  train_env_kwargs = {}
  eval_env_kwargs = {}
  if 'OnlineTuneEnv' in env_name:
    envs_output_dir = FLAGS.envs_output_dir or os.path.join(output_dir, 'envs')
    train_env_output_dir = os.path.join(envs_output_dir, 'train')
    eval_env_output_dir = os.path.join(envs_output_dir, 'eval')
    train_env_kwargs = {'output_dir': train_env_output_dir}
    eval_env_kwargs = {'output_dir': eval_env_output_dir}

  parallelism = multiprocessing.cpu_count() if FLAGS.parallelize_envs else 1
  train_parallelism = min(train_batch_size, parallelism)
  eval_parallelism = min(eval_batch_size, parallelism)

  train_env = env_problem_utils.make_env(
      batch_size=train_batch_size,
      env_problem_name=env_name,
      resize=rendered_env,
      resize_dims=resize_dims,
      max_timestep=max_timestep,
      clip_rewards=clip_rewards,
      parallelism=train_parallelism,
      use_tpu=FLAGS.use_tpu,
      **train_env_kwargs)
  assert train_env

  eval_env = env_problem_utils.make_env(
      batch_size=eval_batch_size,
      env_problem_name=env_name,
      resize=rendered_env,
      resize_dims=resize_dims,
      max_timestep=max_timestep,
      clip_rewards=clip_rewards,
      parallelism=eval_parallelism,
      use_tpu=FLAGS.use_tpu,
      **eval_env_kwargs)
  assert eval_env

  def run_collect_loop():
    async_lib.continuously_collect_trajectories(
        output_dir,
        train_env,
        eval_env,
        trajectory_dump_dir=None,
        env_id=FLAGS.replica,
        try_abort=FLAGS.try_abort,
        max_trajectories_to_collect=(
            None if FLAGS.max_trajectories_to_collect < 0
            else FLAGS.max_trajectories_to_collect))

  if FLAGS.jax_debug_nans or FLAGS.disable_jit:
    with jax.disable_jit():
      run_collect_loop()
  else:
    run_collect_loop()
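# A minimal sketch of an entry point wrapping
# create_envs_and_collect_trajectories(), assuming the absl app/flags setup
# used by the other mains in this section; the flag wiring is illustrative.
def main(argv):
  del argv
  output_dir = os.path.join(FLAGS.output_dir, str(FLAGS.replica))
  create_envs_and_collect_trajectories(output_dir)


if __name__ == '__main__':
  app.run(main)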