Example #1
    def _open_summary_writers(self):
        """Opens the Jaxboard summary writers wrapped by context manager.

    Yields:
      Tuple (train_summary_writer, eval_summary_writer) of Jaxboard summary
      writers wrapped by the GeneratorContextManager object.
      If there was no output_dir provided, yields (None, None).
    """
        if self._output_dir is not None:
            _log('Training and evaluation metrics will be written in %s.' %
                 self._output_dir,
                 stdout=False)
            train_summary_writer = jaxboard.SummaryWriter(
                os.path.join(self._output_dir, 'train'))
            eval_summary_writer = jaxboard.SummaryWriter(
                os.path.join(self._output_dir, 'eval'))
            try:
                yield train_summary_writer, eval_summary_writer
            finally:
                train_summary_writer.close()
                eval_summary_writer.close()
                _log('Training and evaluation metrics were written in %s.' %
                     self._output_dir,
                     stdout=False)
        else:
            yield None, None
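
The docstring above notes that the writers are "wrapped by a context manager": in practice a generator like this is driven by contextlib.contextmanager, so the finally block is guaranteed to close the writers when the enclosing with block exits. Below is a minimal, self-contained sketch of the same pattern; the _FakeWriter class and the '/tmp/run' path are stand-ins for illustration, not the Trax API.

import contextlib
import os


class _FakeWriter:
  """Stand-in for jaxboard.SummaryWriter, used only in this sketch."""

  def __init__(self, logdir):
    self.logdir = logdir

  def close(self):
    print('closed writer for', self.logdir)


@contextlib.contextmanager
def open_summary_writers(output_dir):
  # Same structure as the method above: yield the writers, close them in finally.
  if output_dir is None:
    yield None, None
    return
  train_sw = _FakeWriter(os.path.join(output_dir, 'train'))
  eval_sw = _FakeWriter(os.path.join(output_dir, 'eval'))
  try:
    yield train_sw, eval_sw
  finally:
    train_sw.close()
    eval_sw.close()


with open_summary_writers('/tmp/run') as (train_sw, eval_sw):
  pass  # metrics would be logged through the writers here
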
Example #2
  def reset(self, output_dir, init_checkpoint=None):
    """Reset the model parameters.

    Restores the parameters from the given output_dir if a checkpoint exists,
    otherwise randomly initializes them.

    Does not re-jit the model.

    Args:
      output_dir: Output directory.
      init_checkpoint: Initial checkpoint to use (default $output_dir/model.pkl)
    """
    self.close()
    self._output_dir = output_dir
    if output_dir is not None:
      tf.io.gfile.makedirs(output_dir)
    else:
      assert not self._should_save_checkpoints
      assert not self._should_write_summaries

    # Create summary writers and history.
    if self._should_write_summaries:
      self._train_sw = jaxboard.SummaryWriter(os.path.join(output_dir, 'train'),
                                              enable=self._is_chief)
      self._eval_sw = jaxboard.SummaryWriter(os.path.join(output_dir, 'eval'),
                                             enable=self._is_chief)

    # Reset the train and eval streams.
    self._train_stream = _repeat_stream(self._inputs.train_stream,
                                        self._n_devices)
    # TODO(lukaszkaiser): add an option to evaluate exactly on the full eval
    #   set by adding a padding and stopping the stream when too large.
    self._eval_stream = _repeat_stream(
        self._inputs.eval_stream, self._n_devices)
    self._train_eval_stream = _repeat_stream(
        self._inputs.train_eval_stream, self._n_devices)

    # Restore the training state.
    if output_dir is not None:
      state = load_trainer_state(output_dir, init_checkpoint)
    else:
      state = TrainerState(step=None, opt_state=None,
                           history=trax_history.History(), model_state=None)
    self._step = state.step or 0
    history = state.history
    self._lr_fn = self._lr_schedule(history)
    self._history = history
    if state.opt_state:
      opt_state = state.opt_state
      model_state = state.model_state
    else:
      opt_state, model_state = self._new_opt_state_and_model_state()
      model_state = self._for_n_devices(model_state)
    self._opt_state = OptState(*self._for_n_devices(opt_state))
    self._model_state = model_state
    if not state.opt_state and self._should_save_checkpoints:
      self.save_state(keep=False)

    self.update_nontrainable_params()
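
The restore-or-initialize branch above (load a TrainerState if a checkpoint exists, otherwise build a fresh state and save it) is the core of reset(). Here is a toy sketch of that control flow, with a simplified TrainerState and a made-up pickle path standing in for load_trainer_state; it is not the Trax implementation.

import collections
import os
import pickle

# Simplified stand-in for the Trax TrainerState; field names follow the code above.
TrainerState = collections.namedtuple(
    'TrainerState', ['step', 'opt_state', 'history', 'model_state'])


def load_state(output_dir):
  """Returns the pickled state if a checkpoint exists, else an empty state."""
  path = os.path.join(output_dir, 'model.pkl')
  if os.path.exists(path):
    with open(path, 'rb') as f:
      return pickle.load(f)
  return TrainerState(step=None, opt_state=None, history=[], model_state=None)


state = load_state('/tmp/run')
step = state.step or 0                   # a missing step becomes 0, as above
needs_fresh_init = not state.opt_state   # triggers a fresh opt/model state
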
Example #3
  def reset(self, output_dir=None):
    super(PolicyBasedTrainer, self).reset(output_dir)

    # Create summary writers and history.
    if self._should_write_summaries:
      self._train_sw = jaxboard.SummaryWriter(
          os.path.join(self._output_dir, 'train'))
      self._timing_sw = jaxboard.SummaryWriter(
          os.path.join(self._output_dir, 'timing'))
      self._eval_sw = jaxboard.SummaryWriter(
          os.path.join(self._output_dir, 'eval'))

    # Try to initialize from a saved checkpoint, or initialize from scratch if
    # there is no saved checkpoint.
    self.update_optimization_state(output_dir)

    # If uninitialized, i.e. _policy_and_value_opt_state is None, then
    # initialize.
    if self._policy_and_value_opt_state is None:
      (policy_and_value_net, _) = self._policy_and_value_net_fn()
      obs_space = self.train_env.observation_space
      act_space = self.train_env.action_space
      input_signature = (
          ShapeDtype(
              (1, self._max_timestep + 1) + obs_space.shape, obs_space.dtype
          ),
          ShapeDtype(
              (1, self._max_timestep) + act_space.shape, act_space.dtype
          ),
      )
      weights, self._model_state = policy_and_value_net.init(
          input_signature, rng=self._get_rng()
      )

      # Initialize the optimizer.
      self._init_state_from_weights(weights)

    # If we need to initialize from the world model, do that here.
    if self.init_policy_from_world_model_output_dir is not None:
      weights = policy_based_utils.init_policy_from_world_model_checkpoint(
          self._policy_and_value_net_weights,
          self.init_policy_from_world_model_output_dir,
          self._substitute_fn,
      )
      # Initialize the optimizer.
      self._init_state_from_weights(weights)

    self._n_trajectories_done_since_last_save = 0
    self._last_saved_at_epoch = self.epoch

    if self._async_mode:
      logging.info('Saving model on startup to have a model policy file.')
      self.save()
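
The input_signature built above has one more observation step than action steps: the trajectory stores the state reached after the last action. A small worked example of those shapes, with made-up sizes (a 4-dimensional observation and a scalar discrete action):

# Toy sizes, chosen only to make the shape arithmetic concrete.
max_timestep = 10
obs_shape = (4,)   # e.g. a Box observation space with shape (4,)
act_shape = ()     # e.g. a Discrete action space (scalar actions)

observation_shape = (1, max_timestep + 1) + obs_shape  # (1, 11, 4)
action_shape = (1, max_timestep) + act_shape           # (1, 10)
print(observation_shape, action_shape)
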
Example #4
    def reset(self, output_dir):
        """Reset the model parameters.

    Restores the parameters from the given output_dir if a checkpoint exists,
    otherwise randomly initializes them.

    Does not re-jit the model.

    Args:
      output_dir: Output directory.
    """
        self._output_dir = output_dir
        gfile.makedirs(output_dir)
        # Create summary writers and history.
        if self._should_write_summaries:
            self._train_sw = jaxboard.SummaryWriter(
                os.path.join(output_dir, 'train'), enable=self.is_chief)
            self._eval_sw = jaxboard.SummaryWriter(
                os.path.join(output_dir, 'eval'), enable=self.is_chief)

        # Reset the train and eval streams.
        self._train_stream = self._inputs.train_stream()
        # TODO(lukaszkaiser): add an option to evaluate exactly on the full eval
        #   set by adding a padding and stopping the stream when too large.
        self._eval_stream = _repeat_stream(self._inputs.eval_stream)
        self._train_eval_stream = _repeat_stream(
            self._inputs.train_eval_stream)

        # Restore the training state.
        state = load_trainer_state(output_dir)
        self._step = state.step or 0
        history = state.history
        self._lr_fn = self._lr_schedule(history)
        self._history = history
        if state.opt_state:
            opt_state = state.opt_state
            model_state = state.model_state
        else:
            opt_state, model_state = self._new_opt_state_and_model_state()
            model_state = layers.nested_map(self._maybe_replicate, model_state)
        self._opt_state = OptState(
            *layers.nested_map(self._maybe_replicate, opt_state))
        self._model_state = model_state
        if not state.opt_state and self.is_chief:
            self._maybe_save_state(keep=False)

        self.update_nontrainable_params()
Example #5
    def __init__(self,
                 task: rl_task.RLTask,
                 collect_per_epoch=None,
                 output_dir=None,
                 timestep_to_np=None):
        """Configures the RL Trainer.

    Note that subclasses can have many more arguments, which will be configured
    using defaults and gin. But task and output_dir are passed explicitly.

    Args:
      task: RLTask instance, which defines the environment to train on.
      collect_per_epoch: How many new trajectories to collect in each epoch.
      output_dir: Path telling where to save outputs such as checkpoints.
      timestep_to_np: Timestep-to-numpy function to override in the task.
    """
        self._epoch = 0
        self._task = task
        if timestep_to_np is not None:
            self._task.timestep_to_np = timestep_to_np
        self._collect_per_epoch = collect_per_epoch
        self._output_dir = output_dir
        self._avg_returns = []
        self._sw = None
        if output_dir is not None:
            self._sw = jaxboard.SummaryWriter(os.path.join(output_dir, 'rl'))
Example #6
    def on_step_end(self, step):
        summary_writer = jaxboard.SummaryWriter(
            os.path.join(self._loop.output_dir, 'srl_eval'))
        try:
            self._model.weights = serialization_utils.extract_inner_model(
                self._loop.eval_model.weights)

            metrics = collections.defaultdict(list)
            for _ in range(self._n_steps):
                batch = self._eval_task.next_batch()
                step_metrics = self._eval_batch(batch)
                for (key, value) in step_metrics.items():
                    metrics[key].append(value)

            def metric_name(context, horizon):
                return f'pred_error/context_{context}/horizon_{horizon}'

            metrics = {
                metric_name(context, horizon):
                np.sum(errors) / np.sum(errors != 0)
                for ((context, horizon), errors) in metrics.items()
            }
            self._loop.log_summary(metrics, summary_writer, '', 'srl_eval')
        finally:
            summary_writer.close()
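
The aggregation above collects per-step error arrays under (context, horizon) keys, then reports, per key, the sum of errors divided by the number of non-zero entries (zeros are treated as padding). A self-contained sketch with made-up error values; np.asarray is used here to make the non-zero count explicit.

import collections

import numpy as np

metrics = collections.defaultdict(list)
# Pretend two evaluation batches produced per-step errors for context=1, horizon=4.
for batch_errors in ([0.0, 0.5, 1.0], [0.5, 0.0, 0.5]):
  metrics[(1, 4)].append(np.array(batch_errors))


def metric_name(context, horizon):
  return f'pred_error/context_{context}/horizon_{horizon}'


summary = {
    metric_name(context, horizon):
        float(np.sum(errors) / np.sum(np.asarray(errors) != 0))
    for (context, horizon), errors in metrics.items()
}
print(summary)  # {'pred_error/context_1/horizon_4': 0.625}
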
Example #7
    def _open_summary_writers(self, subdirs):
        """Opens the Jaxboard summary writers wrapped by context manager.

    Args:
      subdirs: List of names of subdirectories to open summary writers for.

    Yields:
      Tuple (writers_1, ..., writers_n) of tuples of Jaxboard summary
      writers wrapped in a GeneratorContextManager object. Elements of the outer
      tuple correspond to subdirs. Elements of the inner tuples correspond to
      tasks. If there was no output_dir provided, yields the same nested tuple
      of None writers.
    """
        if self._output_dir is not None:
            _log(
                'Metrics will be written in {}.'.format(self._output_dir),
                stdout=False,
            )
            writer_per_subdir_and_task = tuple(
                tuple(  # pylint: disable=g-complex-comprehension
                    jaxboard.SummaryWriter(os.path.join(output_dir, subdir))
                    for output_dir in self._output_dir_per_task)
                for subdir in subdirs)
            try:
                yield writer_per_subdir_and_task
            finally:
                for writer_per_task in writer_per_subdir_and_task:
                    for writer in writer_per_task:
                        writer.close()
                _log('Metrics were written in {}.'.format(self._output_dir),
                     stdout=False)
        else:
            yield ((None, ) * len(self._tasks), ) * len(subdirs)
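
The fallback in the else branch yields placeholders with the same nested shape as the real writers, so callers can unpack the result identically whether or not an output_dir was given. For example, with two (hypothetical) tasks and the subdirs ('train', 'eval'):

tasks = ('task_0', 'task_1')   # hypothetical task names
subdirs = ('train', 'eval')

placeholder = ((None,) * len(tasks),) * len(subdirs)
print(placeholder)  # ((None, None), (None, None))
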
Example #8
  def reset(self, output_dir):
    super(PPO, self).reset(output_dir)

    # Initialize the policy and value network.
    # Create the network again to avoid caching of parameters.
    policy_and_value_net = self._policy_and_value_net_fn()
    (batch_obs_shape, obs_dtype) = self._batch_obs_shape_and_dtype(
        self.train_env.observation_space
    )
    self._rng, _ = jax_random.split(self._rng)
    input_signature = ShapeDtype(batch_obs_shape, obs_dtype)
    policy_and_value_net_params, self._model_state = (
        policy_and_value_net.init(input_signature))
    if self.init_policy_from_world_model_output_dir is not None:
      policy_and_value_net_params = ppo.init_policy_from_world_model_checkpoint(
          policy_and_value_net_params,
          self.init_policy_from_world_model_output_dir,
      )

    # Initialize the optimizer.
    (init_slots, init_opt_params) = (
        self._policy_and_value_optimizer.tree_init(policy_and_value_net_params)
    )
    self._policy_and_value_opt_state = (
        policy_and_value_net_params, init_slots, init_opt_params
    )

    # Restore the optimizer state.
    self._epoch = 0
    self._total_opt_step = 0
    self.update_optimization_state(output_dir)

    # Create summary writers and history.
    if self._should_write_summaries:
      self._train_sw = jaxboard.SummaryWriter(
          os.path.join(self._output_dir, 'train'))
      self._timing_sw = jaxboard.SummaryWriter(
          os.path.join(self._output_dir, 'timing'))
      self._eval_sw = jaxboard.SummaryWriter(
          os.path.join(self._output_dir, 'eval'))

    self._n_trajectories_done = 0

    self._last_saved_at = 0
    if self._async_mode:
      logging.info('Saving model on startup to have a model policy file.')
      self.save()
Example #9
  def on_step_end(self, step):
    summary_writer = jaxboard.SummaryWriter(
        os.path.join(self._loop.output_dir, 'srl_eval'))
    try:
      weights = self._loop.eval_model.seq_model_weights
      metrics = self.evaluate(weights)
      self._loop.log_summary(metrics, summary_writer, '', 'srl_eval')
    finally:
      summary_writer.close()
Example #10
  def on_step_end(self, step):
    summary_writer = jaxboard.SummaryWriter(
        os.path.join(self._loop.output_dir, 'srl_eval'))
    try:
      weights = serialization_utils.extract_inner_model(
          self._loop.eval_model.weights)
      metrics = self.evaluate(weights)
      self._loop.log_summary(metrics, summary_writer, '', 'srl_eval')
    finally:
      summary_writer.close()
Example #11
    def reset(self, output_dir=None):
        super(AwrTrainer, self).reset(output_dir=output_dir)

        if self._should_write_summaries:
            self._opt_sw = jaxboard.SummaryWriter(
                os.path.join(self._output_dir, 'opt'))

        # Reset the replay buffer.
        self._replay_buffer.clear()
        self._replay_buffer.buffers = None

        # TODO(afrozm): Ensure that this is updated.
        self._n_observations_seen = 0
Example #12
    def __init__(self,
                 task: rl_task.RLTask,
                 n_trajectories_per_epoch=None,
                 n_interactions_per_epoch=None,
                 n_eval_episodes=0,
                 eval_steps=None,
                 only_eval=False,
                 output_dir=None,
                 timestep_to_np=None):
        """Configures the RL Trainer.

    Note that subclasses can have many more arguments, which will be configured
    using defaults and gin. But task and output_dir are passed explicitly.

    Args:
      task: RLTask instance, which defines the environment to train on.
      n_trajectories_per_epoch: How many new trajectories to collect in each
        epoch.
      n_interactions_per_epoch: How many interactions to collect in each epoch.
      n_eval_episodes: Number of episodes to play with policy at
        temperature 0 in each epoch -- used for evaluation only.
      eval_steps: an optional list of max_steps to use for evaluation
        (defaults to task.max_steps).
      only_eval: If set to True, then trajectories are collected only for
        evaluation purposes, but they are not recorded.
      output_dir: Path telling where to save outputs such as checkpoints.
      timestep_to_np: Timestep-to-numpy function to override in the task.
    """
        assert bool(n_trajectories_per_epoch) != bool(
            n_interactions_per_epoch
        ), ('Exactly one of n_trajectories_per_epoch or n_interactions_per_epoch '
            'should be specified.')
        self._epoch = 0
        self._task = task
        self._eval_steps = eval_steps or [task.max_steps]
        if timestep_to_np is not None:
            self._task.timestep_to_np = timestep_to_np
        self._n_trajectories_per_epoch = n_trajectories_per_epoch
        self._n_interactions_per_epoch = n_interactions_per_epoch
        self._only_eval = only_eval
        self._output_dir = output_dir
        self._avg_returns = []
        self._n_eval_episodes = n_eval_episodes
        self._avg_returns_temperature0 = {
            step: []
            for step in self._eval_steps
        }
        self._sw = None
        if output_dir is not None:
            self._sw = jaxboard.SummaryWriter(os.path.join(output_dir, 'rl'))
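
The assert at the top of this constructor is an "exactly one of" check: the truth values of the two options must differ. A quick illustration (note that passing 0 counts as "not specified", since bool(0) is False):

def exactly_one_specified(n_trajectories_per_epoch, n_interactions_per_epoch):
  return bool(n_trajectories_per_epoch) != bool(n_interactions_per_epoch)

print(exactly_one_specified(10, None))    # True:  trajectories only
print(exactly_one_specified(None, 1000))  # True:  interactions only
print(exactly_one_specified(None, None))  # False: neither given
print(exactly_one_specified(10, 1000))    # False: both given
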
Example #13
  def _open_summary_writer(self):
    """Opens the Jaxboard summary writer wrapped by a context manager.

    Yields:
      A Jaxboard summary writer wrapped in a GeneratorContextManager object,
      writing to the 'rl' subdirectory of output_dir. If there is no output_dir
      provided, yields None.
    """
    if self._output_dir is not None:
      writer = jaxboard.SummaryWriter(os.path.join(self._output_dir, 'rl'))
      try:
        yield writer
      finally:
        writer.close()
    else:
      yield None
Example #14
  def reset(self, output_dir=None):
    super(AwrTrainer, self).reset(output_dir=output_dir)

    if self._should_write_summaries:
      self._opt_sw = jaxboard.SummaryWriter(
          os.path.join(self._output_dir, 'opt'))

    # Reset the replay buffer.
    self._replay_buffer.clear()
    self._replay_buffer.buffers = None
    self._replay_buffer.init_buffers(self._observation_shape,
                                     self._observation_dtype,
                                     self._action_shape, self._action_dtype)
    logging.info(
        'Initialized ReplayBuffer with: obs shape [%s], obs dtype [%s], action shape [%s], action dtype [%s]',
        self._observation_shape, self._observation_dtype, self._action_shape,
        self._action_dtype)

    # TODO(afrozm): Ensure that this is updated.
    self._n_observations_seen = 0
Example #15
    def __init__(self,
                 train_env,
                 eval_env,
                 output_dir,
                 policy_and_value_model=trax_models.FrameStackMLP,
                 policy_and_value_optimizer=functools.partial(
                     trax_opt.Adam, learning_rate=1e-3),
                 policy_and_value_two_towers=False,
                 policy_and_value_vocab_size=None,
                 n_optimizer_steps=N_OPTIMIZER_STEPS,
                 optimizer_batch_size=64,
                 print_every_optimizer_steps=PRINT_EVERY_OPTIMIZER_STEP,
                 target_kl=0.01,
                 boundary=20,
                 max_timestep=100,
                 max_timestep_eval=20000,
                 random_seed=None,
                 gamma=GAMMA,
                 lambda_=LAMBDA,
                 c1=1.0,
                 c2=0.01,
                 eval_every_n=1000,
                 save_every_n=1000,
                 done_frac_for_policy_save=0.5,
                 n_evals=1,
                 len_history_for_policy=4,
                 eval_temperatures=(1.0, 0.5),
                 separate_eval=True,
                 init_policy_from_world_model_output_dir=None,
                 **kwargs):
        """Creates the PPO trainer.

    Args:
      train_env: gym.Env to use for training.
      eval_env: gym.Env to use for evaluation.
      output_dir: Output dir.
      policy_and_value_model: Function defining the policy and value network,
        without the policy and value heads.
      policy_and_value_optimizer: Function defining the optimizer.
      policy_and_value_two_towers: Whether to use two separate models as the
        policy and value networks. If False, share their parameters.
      policy_and_value_vocab_size: Vocabulary size of a policy and value network
        operating on serialized representation. If None, use raw continuous
        representation.
      n_optimizer_steps: Number of optimizer steps.
      optimizer_batch_size: Batch size of an optimizer step.
      print_every_optimizer_steps: How often to log during the policy
        optimization process.
      target_kl: Policy iteration early stopping. Set to infinity to disable
        early stopping.
      boundary: We pad trajectories at integer multiples of this number.
      max_timestep: If set to an integer, maximum number of time-steps in a
        trajectory. Used in the collect procedure.
      max_timestep_eval: If set to an integer, maximum number of time-steps in
        an evaluation trajectory. Used in the collect procedure.
      random_seed: Random seed.
      gamma: Reward discount factor.
      lambda_: N-step TD-error discount factor in GAE.
      c1: Value loss coefficient.
      c2: Entropy loss coefficient.
      eval_every_n: How frequently to eval the policy.
      save_every_n: How frequently to save the policy.
      done_frac_for_policy_save: Fraction of the trajectories that should be
        done to checkpoint the policy.
      n_evals: Number of times to evaluate.
      len_history_for_policy: How much of history to give to the policy.
      eval_temperatures: Sequence of temperatures to try for categorical
        sampling during evaluation.
      separate_eval: Whether to run separate evaluation using a set of
        temperatures. If False, the training reward is reported as evaluation
        reward with temperature 1.0.
      init_policy_from_world_model_output_dir: Model output dir for initializing
        the policy. If None, initialize randomly.
      **kwargs: Additional keyword arguments passed to the base class.
    """
        # Set in base class constructor.
        self._train_env = None
        self._should_reset = None

        super(PPO, self).__init__(train_env, eval_env, output_dir, **kwargs)

        self._n_optimizer_steps = n_optimizer_steps
        self._optimizer_batch_size = optimizer_batch_size
        self._print_every_optimizer_steps = print_every_optimizer_steps
        self._target_kl = target_kl
        self._boundary = boundary
        self._max_timestep = max_timestep
        self._max_timestep_eval = max_timestep_eval
        self._gamma = gamma
        self._lambda_ = lambda_
        self._c1 = c1
        self._c2 = c2
        self._eval_every_n = eval_every_n
        self._save_every_n = save_every_n
        self._done_frac_for_policy_save = done_frac_for_policy_save
        self._n_evals = n_evals
        self._len_history_for_policy = len_history_for_policy
        self._eval_temperatures = eval_temperatures
        self._separate_eval = separate_eval

        action_space = self.train_env.action_space
        assert isinstance(action_space,
                          (gym.spaces.Discrete, gym.spaces.MultiDiscrete))
        if isinstance(action_space, gym.spaces.Discrete):
            n_actions = action_space.n
            n_controls = 1
        else:
            (n_controls, ) = action_space.nvec.shape
            assert n_controls > 0
            assert onp.min(action_space.nvec) == onp.max(action_space.nvec), (
                "Every control must have the same number of actions.")
            n_actions = action_space.nvec[0]
        self._n_actions = n_actions
        self._n_controls = n_controls

        self._rng = trainer_lib.get_random_number_generator_and_set_seed(
            random_seed)
        self._rng, key1 = jax_random.split(self._rng, num=2)

        vocab_size = policy_and_value_vocab_size
        self._serialized_sequence_policy = vocab_size is not None
        if self._serialized_sequence_policy:
            self._serialization_kwargs = self._init_serialization(vocab_size)
        else:
            self._serialization_kwargs = {}

        # Initialize the policy and value network.
        policy_and_value_net = ppo.policy_and_value_net(
            n_actions=n_actions,
            n_controls=n_controls,
            vocab_size=vocab_size,
            bottom_layers_fn=policy_and_value_model,
            two_towers=policy_and_value_two_towers,
        )
        self._policy_and_value_net_apply = jit(policy_and_value_net)
        (batch_obs_shape, obs_dtype) = self._batch_obs_shape_and_dtype
        policy_and_value_net_params, self._model_state = (
            policy_and_value_net.initialize_once(batch_obs_shape, obs_dtype,
                                                 key1))
        if init_policy_from_world_model_output_dir is not None:
            policy_and_value_net_params = ppo.init_policy_from_world_model_checkpoint(
                policy_and_value_net_params,
                init_policy_from_world_model_output_dir)

        # Initialize the optimizer.
        (policy_and_value_opt_state, self._policy_and_value_opt_update,
         self._policy_and_value_get_params) = ppo.optimizer_fn(
             policy_and_value_optimizer, policy_and_value_net_params)

        # Restore the optimizer state.
        self._policy_and_value_opt_state = policy_and_value_opt_state
        self._epoch = 0
        self._total_opt_step = 0
        self.update_optimization_state(
            output_dir, policy_and_value_opt_state=policy_and_value_opt_state)

        # Create summary writers and history.
        self._train_sw = jaxboard.SummaryWriter(
            os.path.join(self._output_dir, "train"))
        self._timing_sw = jaxboard.SummaryWriter(
            os.path.join(self._output_dir, "timing"))
        self._eval_sw = jaxboard.SummaryWriter(
            os.path.join(self._output_dir, "eval"))

        self._n_trajectories_done = 0

        self._last_saved_at = 0
        if self._async_mode:
            logging.info(
                "Saving model on startup to have a model policy file.")
            self.save()

        self._rewards_to_actions = self._init_rewards_to_actions()
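
The action-space handling in the constructor above reduces both supported gym spaces to a pair (n_actions, n_controls): a Discrete space is one control, and a MultiDiscrete space is one control per dimension, all of which must have the same size. A small sketch of that mapping, assuming gym is installed; the helper function is hypothetical.

import gym
import numpy as onp


def n_actions_and_controls(action_space):
  """Returns (n_actions, n_controls) the same way the constructor above does."""
  if isinstance(action_space, gym.spaces.Discrete):
    return int(action_space.n), 1
  (n_controls,) = action_space.nvec.shape
  assert onp.min(action_space.nvec) == onp.max(action_space.nvec), (
      'Every control must have the same number of actions.')
  return int(action_space.nvec[0]), n_controls


print(n_actions_and_controls(gym.spaces.Discrete(4)))            # (4, 1)
print(n_actions_and_controls(gym.spaces.MultiDiscrete([3, 3])))  # (3, 2)
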
Example #16
  def _open_summary_writers(self):
    if self._should_write_summaries:
      _log('Evaluation metrics will be written in %s' % self._output_dir,
           stdout=False)
      self._eval_sw = jaxboard.SummaryWriter(
          os.path.join(self._output_dir, 'eval'))