def _open_summary_writers(self):
  """Opens the Jaxboard summary writers wrapped by context manager.

  Yields:
    Tuple (train_summary_writer, eval_summary_writer) of Jaxboard summary
    writers wrapped by the GeneratorContextManager object. If there was no
    output_dir provided, yields (None, None).
  """
  if self._output_dir is not None:
    _log('Training and evaluation metrics will be written in %s.'
         % self._output_dir, stdout=False)
    train_summary_writer = jaxboard.SummaryWriter(
        os.path.join(self._output_dir, 'train'))
    eval_summary_writer = jaxboard.SummaryWriter(
        os.path.join(self._output_dir, 'eval'))
    try:
      yield train_summary_writer, eval_summary_writer
    finally:
      train_summary_writer.close()
      eval_summary_writer.close()
      _log('Training and evaluation metrics were written in %s.'
           % self._output_dir, stdout=False)
  else:
    yield None, None


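# --- Illustrative usage (added sketch, not part of the original code). -------
# The docstring above says the writers are yielded "wrapped by the
# GeneratorContextManager object", i.e. the method is meant to be used through
# contextlib.contextmanager. Assuming it is decorated that way on the trainer
# class, a hypothetical call site looks like this; `trainer`, `loss_value` and
# `step` are placeholder names.
def _example_evaluation_with_writers(trainer, loss_value, step):
  """Logs one scalar inside the writer context; both writers close on exit."""
  with trainer._open_summary_writers() as (train_sw, eval_sw):
    if eval_sw is not None:
      eval_sw.scalar('eval/loss', loss_value, step=step)
    # On leaving the `with` block the writers are closed, even if the body
    # raised an exception.

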
def reset(self, output_dir, init_checkpoint=None):
  """Reset the model parameters.

  Restores the parameters from the given output_dir if a checkpoint exists,
  otherwise randomly initializes them.

  Does not re-jit the model.

  Args:
    output_dir: Output directory.
    init_checkpoint: Initial checkpoint to use (default $output_dir/model.pkl)
  """
  self.close()
  self._output_dir = output_dir
  if output_dir is not None:
    tf.io.gfile.makedirs(output_dir)
  else:
    assert not self._should_save_checkpoints
    assert not self._should_write_summaries

  # Create summary writers and history.
  if self._should_write_summaries:
    self._train_sw = jaxboard.SummaryWriter(os.path.join(output_dir, 'train'),
                                            enable=self._is_chief)
    self._eval_sw = jaxboard.SummaryWriter(os.path.join(output_dir, 'eval'),
                                           enable=self._is_chief)

  # Reset the train and eval streams.
  self._train_stream = _repeat_stream(self._inputs.train_stream,
                                      self._n_devices)
  # TODO(lukaszkaiser): add an option to evaluate exactly on the full eval
  #   set by adding a padding and stopping the stream when too large.
  self._eval_stream = _repeat_stream(
      self._inputs.eval_stream, self._n_devices)
  self._train_eval_stream = _repeat_stream(
      self._inputs.train_eval_stream, self._n_devices)

  # Restore the training state.
  if output_dir is not None:
    state = load_trainer_state(output_dir, init_checkpoint)
  else:
    state = TrainerState(step=None, opt_state=None,
                         history=trax_history.History(), model_state=None)
  self._step = state.step or 0
  history = state.history
  self._lr_fn = self._lr_schedule(history)
  self._history = history
  if state.opt_state:
    opt_state = state.opt_state
    model_state = state.model_state
  else:
    opt_state, model_state = self._new_opt_state_and_model_state()
    model_state = self._for_n_devices(model_state)
  self._opt_state = OptState(*self._for_n_devices(opt_state))
  self._model_state = model_state
  if not state.opt_state and self._should_save_checkpoints:
    self.save_state(keep=False)

  self.update_nontrainable_params()


def reset(self, output_dir=None):
  super(PolicyBasedTrainer, self).reset(output_dir)

  # Create summary writers and history.
  if self._should_write_summaries:
    self._train_sw = jaxboard.SummaryWriter(
        os.path.join(self._output_dir, 'train'))
    self._timing_sw = jaxboard.SummaryWriter(
        os.path.join(self._output_dir, 'timing'))
    self._eval_sw = jaxboard.SummaryWriter(
        os.path.join(self._output_dir, 'eval'))

  # Try to initialize from a saved checkpoint, or initialize from scratch if
  # there is no saved checkpoint.
  self.update_optimization_state(output_dir)

  # If uninitialized, i.e. _policy_and_value_opt_state is None, then
  # initialize.
  if self._policy_and_value_opt_state is None:
    (policy_and_value_net, _) = self._policy_and_value_net_fn()
    obs_space = self.train_env.observation_space
    act_space = self.train_env.action_space
    input_signature = (
        ShapeDtype(
            (1, self._max_timestep + 1) + obs_space.shape, obs_space.dtype
        ),
        ShapeDtype(
            (1, self._max_timestep) + act_space.shape, act_space.dtype
        ),
    )
    weights, self._model_state = policy_and_value_net.init(
        input_signature, rng=self._get_rng()
    )
    # Initialize the optimizer.
    self._init_state_from_weights(weights)

  # If we need to initialize from the world model, do that here.
  if self.init_policy_from_world_model_output_dir is not None:
    weights = policy_based_utils.init_policy_from_world_model_checkpoint(
        self._policy_and_value_net_weights,
        self.init_policy_from_world_model_output_dir,
        self._substitute_fn,
    )
    # Initialize the optimizer.
    self._init_state_from_weights(weights)

  self._n_trajectories_done_since_last_save = 0
  self._last_saved_at_epoch = self.epoch

  if self._async_mode:
    logging.info('Saving model on startup to have a model policy file.')
    self.save()


def reset(self, output_dir):
  """Reset the model parameters.

  Restores the parameters from the given output_dir if a checkpoint exists,
  otherwise randomly initializes them.

  Does not re-jit the model.

  Args:
    output_dir: Output directory.
  """
  self._output_dir = output_dir
  gfile.makedirs(output_dir)

  # Create summary writers and history.
  if self._should_write_summaries:
    self._train_sw = jaxboard.SummaryWriter(os.path.join(output_dir, 'train'),
                                            enable=self.is_chief)
    self._eval_sw = jaxboard.SummaryWriter(os.path.join(output_dir, 'eval'),
                                           enable=self.is_chief)

  # Reset the train and eval streams.
  self._train_stream = self._inputs.train_stream()
  # TODO(lukaszkaiser): add an option to evaluate exactly on the full eval
  #   set by adding a padding and stopping the stream when too large.
  self._eval_stream = _repeat_stream(self._inputs.eval_stream)
  self._train_eval_stream = _repeat_stream(self._inputs.train_eval_stream)

  # Restore the training state.
  state = load_trainer_state(output_dir)
  self._step = state.step or 0
  history = state.history
  self._lr_fn = self._lr_schedule(history)
  self._history = history
  if state.opt_state:
    opt_state = state.opt_state
    model_state = state.model_state
  else:
    opt_state, model_state = self._new_opt_state_and_model_state()
    model_state = layers.nested_map(self._maybe_replicate, model_state)
  self._opt_state = OptState(
      *layers.nested_map(self._maybe_replicate, opt_state))
  self._model_state = model_state
  if not state.opt_state and self.is_chief:
    self._maybe_save_state(keep=False)

  self.update_nontrainable_params()


def __init__(self, task: rl_task.RLTask,
             collect_per_epoch=None,
             output_dir=None,
             timestep_to_np=None):
  """Configures the RL Trainer.

  Note that subclasses can have many more arguments, which will be configured
  using defaults and gin. But task and output_dir are passed explicitly.

  Args:
    task: RLTask instance, which defines the environment to train on.
    collect_per_epoch: How many new trajectories to collect in each epoch.
    output_dir: Path telling where to save outputs such as checkpoints.
    timestep_to_np: Timestep-to-numpy function to override in the task.
  """
  self._epoch = 0
  self._task = task
  if timestep_to_np is not None:
    self._task.timestep_to_np = timestep_to_np
  self._collect_per_epoch = collect_per_epoch
  self._output_dir = output_dir
  self._avg_returns = []
  self._sw = None
  if output_dir is not None:
    self._sw = jaxboard.SummaryWriter(os.path.join(output_dir, 'rl'))


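# --- Illustrative usage (added sketch, not part of the original code). -------
# The constructor above only opens the 'rl' summary writer; the training loop
# is expected to log to it once per epoch. The helper and the tag
# 'rl/avg_return' below are hypothetical, and the scalar()/flush() calls
# assume the jaxboard.SummaryWriter API used elsewhere in this file.
def _example_log_epoch_return(trainer, avg_return):
  """Records the epoch's average return and writes it if a writer exists."""
  trainer._avg_returns.append(avg_return)
  if trainer._sw is not None:
    trainer._sw.scalar('rl/avg_return', avg_return, step=trainer._epoch)
    trainer._sw.flush()

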
def on_step_end(self, step):
  summary_writer = jaxboard.SummaryWriter(
      os.path.join(self._loop.output_dir, 'srl_eval'))
  try:
    self._model.weights = serialization_utils.extract_inner_model(
        self._loop.eval_model.weights)

    metrics = collections.defaultdict(list)
    for _ in range(self._n_steps):
      batch = self._eval_task.next_batch()
      step_metrics = self._eval_batch(batch)
      for (key, value) in step_metrics.items():
        metrics[key].append(value)

    def metric_name(context, horizon):
      return f'pred_error/context_{context}/horizon_{horizon}'

    metrics = {
        metric_name(context, horizon): np.sum(errors) / np.sum(errors != 0)
        for ((context, horizon), errors) in metrics.items()
    }
    self._loop.log_summary(metrics, summary_writer, '', 'srl_eval')
  finally:
    summary_writer.close()


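# --- Worked example of the aggregation above (added; synthetic numbers). -----
# Each value collected under a (context, horizon) key is an array of per-step
# prediction errors in which padded positions are exactly 0, so the intended
# statistic is the mean error over the non-padded entries. Stacking the list
# into a single array first keeps the `!= 0` comparison elementwise.
import numpy as onp  # Plain NumPy, used only for this illustration.


def _mean_nonzero_error_example():
  """Returns the mean over non-zero error entries: (0.2 + 0.4 + 0.6) / 3."""
  errors = [onp.array([0.2, 0.0, 0.4]), onp.array([0.0, 0.6, 0.0])]
  stacked = onp.stack(errors)
  return onp.sum(stacked) / onp.sum(stacked != 0)  # == 0.4

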
def _open_summary_writers(self, subdirs):
  """Opens the Jaxboard summary writers wrapped by context manager.

  Args:
    subdirs: List of names of subdirectories to open summary writers for.

  Yields:
    Tuple (writers_1, ..., writers_n) of tuples of Jaxboard summary writers
    wrapped in a GeneratorContextManager object. Elements of the outer tuple
    correspond to subdirs. Elements of the inner tuples correspond to tasks.
    If there was no output_dir provided, yields the same nested tuple of
    None writers.
  """
  if self._output_dir is not None:
    _log(
        'Metrics will be written in {}.'.format(self._output_dir),
        stdout=False,
    )
    writer_per_subdir_and_task = tuple(
        tuple(  # pylint: disable=g-complex-comprehension
            jaxboard.SummaryWriter(os.path.join(output_dir, subdir))
            for output_dir in self._output_dir_per_task)
        for subdir in subdirs)
    try:
      yield writer_per_subdir_and_task
    finally:
      for writer_per_task in writer_per_subdir_and_task:
        for writer in writer_per_task:
          writer.close()
      _log('Metrics were written in {}.'.format(self._output_dir),
           stdout=False)
  else:
    yield ((None,) * len(self._tasks),) * len(subdirs)


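# --- Illustrative usage (added sketch, not part of the original code). -------
# Assuming the method is wrapped with contextlib.contextmanager, a call with
# subdirs=('train', 'eval') yields one tuple of writers per subdir, each with
# one writer per task. The helper below is hypothetical; `losses` is a
# per-task sequence of scalars.
def _example_log_per_task(trainer, losses, step):
  """Writes one training scalar per task, skipping when there is no writer."""
  with trainer._open_summary_writers(('train', 'eval')) as (
      train_writers, eval_writers):
    del eval_writers  # Unused in this sketch.
    for (task_id, writer) in enumerate(train_writers):
      if writer is not None:
        writer.scalar('train/loss', losses[task_id], step=step)

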
def reset(self, output_dir):
  super(PPO, self).reset(output_dir)

  # Initialize the policy and value network.
  # Create the network again to avoid caching of parameters.
  policy_and_value_net = self._policy_and_value_net_fn()
  (batch_obs_shape, obs_dtype) = self._batch_obs_shape_and_dtype(
      self.train_env.observation_space
  )
  self._rng, _ = jax_random.split(self._rng)
  input_signature = ShapeDtype(batch_obs_shape, obs_dtype)
  policy_and_value_net_params, self._model_state = (
      policy_and_value_net.init(input_signature))
  if self.init_policy_from_world_model_output_dir is not None:
    policy_and_value_net_params = ppo.init_policy_from_world_model_checkpoint(
        policy_and_value_net_params,
        self.init_policy_from_world_model_output_dir,
    )

  # Initialize the optimizer.
  (init_slots, init_opt_params) = (
      self._policy_and_value_optimizer.tree_init(policy_and_value_net_params)
  )
  self._policy_and_value_opt_state = (
      policy_and_value_net_params, init_slots, init_opt_params
  )

  # Restore the optimizer state.
  self._epoch = 0
  self._total_opt_step = 0
  self.update_optimization_state(output_dir)

  # Create summary writers and history.
  if self._should_write_summaries:
    self._train_sw = jaxboard.SummaryWriter(
        os.path.join(self._output_dir, 'train'))
    self._timing_sw = jaxboard.SummaryWriter(
        os.path.join(self._output_dir, 'timing'))
    self._eval_sw = jaxboard.SummaryWriter(
        os.path.join(self._output_dir, 'eval'))

  self._n_trajectories_done = 0
  self._last_saved_at = 0

  if self._async_mode:
    logging.info('Saving model on startup to have a model policy file.')
    self.save()


def on_step_end(self, step):
  summary_writer = jaxboard.SummaryWriter(
      os.path.join(self._loop.output_dir, 'srl_eval'))
  try:
    weights = self._loop.eval_model.seq_model_weights
    metrics = self.evaluate(weights)
    self._loop.log_summary(metrics, summary_writer, '', 'srl_eval')
  finally:
    summary_writer.close()


def on_step_end(self, step):
  summary_writer = jaxboard.SummaryWriter(
      os.path.join(self._loop.output_dir, 'srl_eval'))
  try:
    weights = serialization_utils.extract_inner_model(
        self._loop.eval_model.weights)
    metrics = self.evaluate(weights)
    self._loop.log_summary(metrics, summary_writer, '', 'srl_eval')
  finally:
    summary_writer.close()


def reset(self, output_dir=None):
  super(AwrTrainer, self).reset(output_dir=output_dir)

  if self._should_write_summaries:
    self._opt_sw = jaxboard.SummaryWriter(
        os.path.join(self._output_dir, 'opt'))

  # Reset the replay buffer.
  self._replay_buffer.clear()
  self._replay_buffer.buffers = None

  # TODO(afrozm): Ensure that this is updated.
  self._n_observations_seen = 0


def __init__(self, task: rl_task.RLTask,
             n_trajectories_per_epoch=None,
             n_interactions_per_epoch=None,
             n_eval_episodes=0,
             eval_steps=None,
             only_eval=False,
             output_dir=None,
             timestep_to_np=None):
  """Configures the RL Trainer.

  Note that subclasses can have many more arguments, which will be configured
  using defaults and gin. But task and output_dir are passed explicitly.

  Args:
    task: RLTask instance, which defines the environment to train on.
    n_trajectories_per_epoch: How many new trajectories to collect in each
      epoch.
    n_interactions_per_epoch: How many interactions to collect in each epoch.
    n_eval_episodes: Number of episodes to play with policy at temperature 0
      in each epoch -- used for evaluation only.
    eval_steps: an optional list of max_steps to use for evaluation (defaults
      to task.max_steps).
    only_eval: If set to True, then trajectories are collected only for
      evaluation purposes, but they are not recorded.
    output_dir: Path telling where to save outputs such as checkpoints.
    timestep_to_np: Timestep-to-numpy function to override in the task.
  """
  assert bool(n_trajectories_per_epoch) != bool(
      n_interactions_per_epoch
  ), ('Exactly one of n_trajectories_per_epoch or n_interactions_per_epoch '
      'should be specified.')
  self._epoch = 0
  self._task = task
  self._eval_steps = eval_steps or [task.max_steps]
  if timestep_to_np is not None:
    self._task.timestep_to_np = timestep_to_np
  self._n_trajectories_per_epoch = n_trajectories_per_epoch
  self._n_interactions_per_epoch = n_interactions_per_epoch
  self._only_eval = only_eval
  self._output_dir = output_dir
  self._avg_returns = []
  self._n_eval_episodes = n_eval_episodes
  self._avg_returns_temperature0 = {
      step: [] for step in self._eval_steps
  }
  self._sw = None
  if output_dir is not None:
    self._sw = jaxboard.SummaryWriter(os.path.join(output_dir, 'rl'))


def _open_summary_writer(self):
  """Opens the Jaxboard summary writer wrapped by a context manager.

  Yields:
    A Jaxboard summary writer wrapped in a GeneratorContextManager object,
    writing to the 'rl' subdirectory of output_dir. If there is no output_dir
    provided, yields None.
  """
  if self._output_dir is not None:
    writer = jaxboard.SummaryWriter(os.path.join(self._output_dir, 'rl'))
    try:
      yield writer
    finally:
      writer.close()
  else:
    yield None


def reset(self, output_dir=None):
  super(AwrTrainer, self).reset(output_dir=output_dir)

  if self._should_write_summaries:
    self._opt_sw = jaxboard.SummaryWriter(
        os.path.join(self._output_dir, 'opt'))

  # Reset the replay buffer.
  self._replay_buffer.clear()
  self._replay_buffer.buffers = None
  self._replay_buffer.init_buffers(self._observation_shape,
                                   self._observation_dtype,
                                   self._action_shape,
                                   self._action_dtype)
  logging.info(
      'Initialized ReplayBuffer with: obs shape [%s], obs dtype [%s], '
      'action shape [%s], action dtype [%s]', self._observation_shape,
      self._observation_dtype, self._action_shape, self._action_dtype)

  # TODO(afrozm): Ensure that this is updated.
  self._n_observations_seen = 0


def __init__(self,
             train_env,
             eval_env,
             output_dir,
             policy_and_value_model=trax_models.FrameStackMLP,
             policy_and_value_optimizer=functools.partial(
                 trax_opt.Adam, learning_rate=1e-3),
             policy_and_value_two_towers=False,
             policy_and_value_vocab_size=None,
             n_optimizer_steps=N_OPTIMIZER_STEPS,
             optimizer_batch_size=64,
             print_every_optimizer_steps=PRINT_EVERY_OPTIMIZER_STEP,
             target_kl=0.01,
             boundary=20,
             max_timestep=100,
             max_timestep_eval=20000,
             random_seed=None,
             gamma=GAMMA,
             lambda_=LAMBDA,
             c1=1.0,
             c2=0.01,
             eval_every_n=1000,
             save_every_n=1000,
             done_frac_for_policy_save=0.5,
             n_evals=1,
             len_history_for_policy=4,
             eval_temperatures=(1.0, 0.5),
             separate_eval=True,
             init_policy_from_world_model_output_dir=None,
             **kwargs):
  """Creates the PPO trainer.

  Args:
    train_env: gym.Env to use for training.
    eval_env: gym.Env to use for evaluation.
    output_dir: Output dir.
    policy_and_value_model: Function defining the policy and value network,
      without the policy and value heads.
    policy_and_value_optimizer: Function defining the optimizer.
    policy_and_value_two_towers: Whether to use two separate models as the
      policy and value networks. If False, share their parameters.
    policy_and_value_vocab_size: Vocabulary size of a policy and value network
      operating on serialized representation. If None, use raw continuous
      representation.
    n_optimizer_steps: Number of optimizer steps.
    optimizer_batch_size: Batch size of an optimizer step.
    print_every_optimizer_steps: How often to log during the policy
      optimization process.
    target_kl: Policy iteration early stopping. Set to infinity to disable
      early stopping.
    boundary: We pad trajectories at integer multiples of this number.
    max_timestep: If set to an integer, maximum number of time-steps in a
      trajectory. Used in the collect procedure.
    max_timestep_eval: If set to an integer, maximum number of time-steps in
      an evaluation trajectory. Used in the collect procedure.
    random_seed: Random seed.
    gamma: Reward discount factor.
    lambda_: N-step TD-error discount factor in GAE.
    c1: Value loss coefficient.
    c2: Entropy loss coefficient.
    eval_every_n: How frequently to eval the policy.
    save_every_n: How frequently to save the policy.
    done_frac_for_policy_save: Fraction of the trajectories that should be
      done to checkpoint the policy.
    n_evals: Number of times to evaluate.
    len_history_for_policy: How much of history to give to the policy.
    eval_temperatures: Sequence of temperatures to try for categorical
      sampling during evaluation.
    separate_eval: Whether to run separate evaluation using a set of
      temperatures. If False, the training reward is reported as evaluation
      reward with temperature 1.0.
    init_policy_from_world_model_output_dir: Model output dir for initializing
      the policy. If None, initialize randomly.
    **kwargs: Additional keyword arguments passed to the base class.
  """
  # Set in base class constructor.
  self._train_env = None
  self._should_reset = None

  super(PPO, self).__init__(train_env, eval_env, output_dir, **kwargs)

  self._n_optimizer_steps = n_optimizer_steps
  self._optimizer_batch_size = optimizer_batch_size
  self._print_every_optimizer_steps = print_every_optimizer_steps
  self._target_kl = target_kl
  self._boundary = boundary
  self._max_timestep = max_timestep
  self._max_timestep_eval = max_timestep_eval
  self._gamma = gamma
  self._lambda_ = lambda_
  self._c1 = c1
  self._c2 = c2
  self._eval_every_n = eval_every_n
  self._save_every_n = save_every_n
  self._done_frac_for_policy_save = done_frac_for_policy_save
  self._n_evals = n_evals
  self._len_history_for_policy = len_history_for_policy
  self._eval_temperatures = eval_temperatures
  self._separate_eval = separate_eval

  action_space = self.train_env.action_space
  assert isinstance(action_space,
                    (gym.spaces.Discrete, gym.spaces.MultiDiscrete))
  if isinstance(action_space, gym.spaces.Discrete):
    n_actions = action_space.n
    n_controls = 1
  else:
    (n_controls,) = action_space.nvec.shape
    assert n_controls > 0
    assert onp.min(action_space.nvec) == onp.max(action_space.nvec), (
        "Every control must have the same number of actions.")
    n_actions = action_space.nvec[0]
  self._n_actions = n_actions
  self._n_controls = n_controls

  self._rng = trainer_lib.get_random_number_generator_and_set_seed(
      random_seed)
  self._rng, key1 = jax_random.split(self._rng, num=2)

  vocab_size = policy_and_value_vocab_size
  self._serialized_sequence_policy = vocab_size is not None
  if self._serialized_sequence_policy:
    self._serialization_kwargs = self._init_serialization(vocab_size)
  else:
    self._serialization_kwargs = {}

  # Initialize the policy and value network.
  policy_and_value_net = ppo.policy_and_value_net(
      n_actions=n_actions,
      n_controls=n_controls,
      vocab_size=vocab_size,
      bottom_layers_fn=policy_and_value_model,
      two_towers=policy_and_value_two_towers,
  )
  self._policy_and_value_net_apply = jit(policy_and_value_net)
  (batch_obs_shape, obs_dtype) = self._batch_obs_shape_and_dtype
  policy_and_value_net_params, self._model_state = (
      policy_and_value_net.initialize_once(batch_obs_shape, obs_dtype, key1))
  if init_policy_from_world_model_output_dir is not None:
    policy_and_value_net_params = ppo.init_policy_from_world_model_checkpoint(
        policy_and_value_net_params, init_policy_from_world_model_output_dir)

  # Initialize the optimizer.
  (policy_and_value_opt_state, self._policy_and_value_opt_update,
   self._policy_and_value_get_params) = ppo.optimizer_fn(
       policy_and_value_optimizer, policy_and_value_net_params)

  # Restore the optimizer state.
  self._policy_and_value_opt_state = policy_and_value_opt_state
  self._epoch = 0
  self._total_opt_step = 0
  self.update_optimization_state(
      output_dir, policy_and_value_opt_state=policy_and_value_opt_state)

  # Create summary writers and history.
  self._train_sw = jaxboard.SummaryWriter(
      os.path.join(self._output_dir, "train"))
  self._timing_sw = jaxboard.SummaryWriter(
      os.path.join(self._output_dir, "timing"))
  self._eval_sw = jaxboard.SummaryWriter(
      os.path.join(self._output_dir, "eval"))

  self._n_trajectories_done = 0
  self._last_saved_at = 0
  if self._async_mode:
    logging.info("Saving model on startup to have a model policy file.")
    self.save()

  self._rewards_to_actions = self._init_rewards_to_actions()


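# --- Illustrative construction (added sketch, not part of the original code).
# A hypothetical way to build the PPO trainer with the defaults documented
# above; 'CartPole-v0' is an arbitrary discrete-action environment and the
# base-class **kwargs are left at their defaults.
def _example_build_ppo_trainer(output_dir):
  """Builds a PPO trainer on CartPole using the default policy/value model."""
  import gym  # Local import used only for this sketch.
  return PPO(
      train_env=gym.make('CartPole-v0'),
      eval_env=gym.make('CartPole-v0'),
      output_dir=output_dir,
  )

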
def _open_summary_writers(self):
  if self._should_write_summaries:
    _log('Evaluation metrics will be written in %s' % self._output_dir,
         stdout=False)
    self._eval_sw = jaxboard.SummaryWriter(
        os.path.join(self._output_dir, 'eval'))