Example #1
def init_policy_from_world_model_checkpoint(policy_params, model_output_dir):
    """Initializes policy parameters from world model parameters."""
    pkl_module = utils.get_pickle_module()
    params_file = os.path.join(model_output_dir, "model.pkl")
    # Don't use trax.restore_state to avoid a circular import.
    with gfile.GFile(params_file, "rb") as f:
        model_params = pkl_module.load(f)[0][0]
    # TODO(pkozakowski): The following, brittle line of code is hardcoded for
    # transplanting parameters from TransformerLM to TransformerDecoder-based
    # policy network of the same configuration. Figure out a more general method.
    policy_params[0] = model_params[0][1:-2]
    return policy_params
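Note: every snippet on this page obtains its pickle implementation through utils.get_pickle_module(). That helper is not shown here; as a minimal sketch of what such a function typically does, assuming it only switches between the standard pickle module and cloudpickle (hypothetical, not the actual trax code):

def get_pickle_module(use_cloudpickle=False):
    # cloudpickle can serialize objects plain pickle cannot (e.g. lambdas),
    # which is why such a switch is sometimes useful for checkpointing.
    if use_cloudpickle:
        import cloudpickle
        return cloudpickle
    import pickle
    return pickle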
Example #2
def init_policy_from_world_model_checkpoint(policy_weights, model_output_dir,
                                            substitute_fn):
    """Initializes policy parameters from world model parameters."""
    pkl_module = utils.get_pickle_module()
    weights_file = os.path.join(model_output_dir, 'model.pkl')
    # Don't use trax.load_trainer_state to avoid a circular import.
    with tf.io.gfile.GFile(weights_file, 'rb') as f:
        model_weights = pkl_module.load(f)['weights']
    model_weights = serialization_utils.extract_inner_model(model_weights)
    # TODO(pkozakowski): The following, brittle line of code is hardcoded for
    # transplanting parameters from TransformerLM to TransformerDecoder-based
    # policy network of the same configuration. Figure out a more general method.
    return substitute_fn(policy_weights, model_weights[1:-2])
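Compared with Example #1, this version delegates the actual weight surgery to a caller-supplied substitute_fn. A minimal sketch of a compatible callable, assuming the transplanted layers belong at index 0 of a list-like weights structure as in Example #1 (hypothetical helper, not part of the original source):

def substitute_first_block(policy_weights, transplanted_weights):
    # Replace the first weight block of the policy with the layers extracted
    # from the world model, leaving the remaining blocks untouched.
    policy_weights = list(policy_weights)
    policy_weights[0] = transplanted_weights
    return policy_weights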
Example #3
def save_trainer_state(state, output_dir, keep=False):
  """Saves a TrainerState instance to the given `output_dir`."""
  pkl_module = utils.get_pickle_module()
  weights_file = os.path.join(output_dir, 'model.pkl')
  with gfile.GFile(weights_file, 'wb') as f:
    pkl_module.dump((tuple(state.opt_state), state.step, state.history,
                     state.model_state), f)
  if keep:
    weights_file = os.path.join(output_dir, 'model_{}.pkl'.format(state.step))
    with gfile.GFile(weights_file, 'wb') as f:
      pkl_module.dump((tuple(state.opt_state), state.step, state.history,
                       state.model_state), f)
  log('Model saved to %s' % weights_file, stdout=False)
Example #4
def save_opt_state(output_dir, policy_and_value_opt_state,
                   policy_and_value_state, epoch, total_opt_step):
    """Saves the policy and value network optimization state etc."""
    pkl_module = utils.get_pickle_module()
    old_model_files = get_policy_model_files(output_dir)
    params_file = os.path.join(output_dir, "model-%06d.pkl" % epoch)
    with gfile.GFile(params_file, "wb") as f:
        pkl_module.dump((policy_and_value_opt_state, policy_and_value_state,
                         total_opt_step), f)
    # Keep the last k model files lying around (note k > 1 because the latest
    # model file might be in the process of getting read async).
    for path in old_model_files[LAST_N_POLICY_MODELS_TO_KEEP:]:
        if path != params_file:
            gfile.remove(path)
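The pruning loop above assumes get_policy_model_files() returns existing checkpoints newest first, which is also what lets Example #7 fall back to older checkpoints when the newest one is corrupt. A plausible sketch of that helper, built on the 'model-%06d.pkl' naming used above (an assumption, not the original source):

import os
import tensorflow as tf

def get_policy_model_files(output_dir):
    # Zero-padded epoch numbers make lexicographic order match numeric order,
    # so sorting the matched paths in reverse lists the newest checkpoint first.
    pattern = os.path.join(output_dir, 'model-??????.pkl')
    return sorted(tf.io.gfile.glob(pattern), reverse=True)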
Example #5
def load_trainer_state(output_dir):
  """Returns a TrainerState instance loaded from the given `output_dir`."""
  weights_file = os.path.join(output_dir, 'model.pkl')
  if not gfile.exists(weights_file):
    return TrainerState(step=None, opt_state=None,
                        history=trax_history.History(), model_state=None)

  pkl_module = utils.get_pickle_module()
  with gfile.GFile(weights_file, 'rb') as f:
    (opt_state, step, history, model_state) = pkl_module.load(f)
  log('Model loaded from %s at step %d' % (weights_file, step))
  logging.debug('From loaded model : history = %s', history)
  return TrainerState(step=step, opt_state=OptState(*opt_state),
                      history=history, model_state=model_state)
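The TrainerState record that save_trainer_state and load_trainer_state pass around holds exactly the four fields pickled above. A minimal stand-in definition, assuming a plain namedtuple is sufficient (the real trax definition may carry additional structure):

import collections

TrainerState = collections.namedtuple(
    'TrainerState', ['opt_state', 'step', 'history', 'model_state'])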
Example #6
    def dump_trajectories(self, force=False):
        """Dumps trajectories in a new shard.

    Should be called at most once per epoch.

    Args:
      force: (bool) Whether to complete unfinished trajectories and create
        a new shard even if we have not reached the minimum size.
    """
        pkl_module = utils.get_pickle_module()
        if self.trajectory_dump_dir is None:
            return
        gfile.makedirs(self.trajectory_dump_dir)

        trajectories = self.train_env.trajectories
        if force:
            trajectories.complete_all_trajectories()

        # complete_all_trajectories() also adds trajectories that were just reset.
        # We don't want them since they have just the initial observation and no
        # actions, so we filter them out.
        def has_any_action(trajectory):
            return (trajectory.time_steps
                    and trajectory.time_steps[0].action is not None)

        self._trajectory_buffer.extend(
            filter(has_any_action, trajectories.completed_trajectories))

        trajectories.clear_completed_trajectories()
        ready = (len(self._trajectory_buffer) >=
                 self._trajectory_dump_min_count_per_shard)
        if ready or force:
            shard_path = os.path.join(self.trajectory_dump_dir,
                                      "{}.pkl".format(self.epoch))
            if gfile.exists(shard_path):
                # Since we do an extra dump at the end of the training loop, we
                # sometimes dump 2 times in the same epoch. When this happens, merge the
                # two sets of trajectories.
                with gfile.GFile(shard_path, "rb") as f:
                    self._trajectory_buffer = (
                        pkl_module.load(f) + self._trajectory_buffer)
            with gfile.GFile(shard_path, "wb") as f:
                pkl_module.dump(self._trajectory_buffer, f)
            self._trajectory_buffer = []
Example #7
def maybe_restore_opt_state(output_dir,
                            policy_and_value_opt_state=None,
                            policy_and_value_state=None):
    """Maybe restore the optimization state from the checkpoint dir.

  Optimization state includes parameters and optimizer slots.

  Args:
    output_dir: Directory where saved model checkpoints are stored.
    policy_and_value_opt_state: Default optimization state, returned if model
      isn't found.
    policy_and_value_state: state of the policy and value network.

  Returns:
    tuple (opt_state, state, epoch (int), opt_step (int)) where epoch is the
    epoch from which we restored the optimization state, 0 if no checkpoint was
    found, and opt_step is the total optimization step (sum of all optimization
    steps made up to the current epoch).
  """
    pkl_module = utils.get_pickle_module()
    epoch = 0
    total_opt_step = 0
    history = trax_history.History()
    for model_file in get_policy_model_files(output_dir):
        logging.info('Trying to restore model from %s', model_file)
        try:
            with tf.io.gfile.GFile(model_file, 'rb') as f:
                (policy_and_value_opt_state, policy_and_value_state,
                 total_opt_step, history) = pkl_module.load(f)
            epoch = get_epoch_from_policy_model_file(model_file)
            break
        except EOFError as e:
            logging.error('Unable to load model from: %s with %s', model_file,
                          e)
            # Try an older version.
            continue
    return (
        policy_and_value_opt_state,
        policy_and_value_state,
        epoch,
        total_opt_step,
        history,
    )
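The restore loop relies on get_epoch_from_policy_model_file() to recover the epoch number from a checkpoint filename. Given the 'model-%06d.pkl' convention from Example #4, a plausible sketch (hypothetical, not the original source):

import os
import re

def get_epoch_from_policy_model_file(model_file):
    # 'model-000123.pkl' -> 123, relying on the zero-padded naming convention
    # used by save_opt_state in Example #4.
    match = re.match(r'model-(\d+)\.pkl', os.path.basename(model_file))
    return int(match.group(1))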
Example #8
  def save_state(self, keep):
    """Save trainer state given a possibly replicated opt_state."""
    opt_state = self._opt_state
    if self.n_devices > 1:
      first_replica = lambda x: x[0]
      opt_state = OptState(*backend.nested_map(first_replica, opt_state))
    # This line, while optional, allows JAX to transfer arrays from the device
    # to the host in parallel, which is particularly important for cloud TPU.
    if backend.get_name() == 'jax':
      opt_state = jax.device_get(opt_state)
    step, history, model_state = self._step, self._history, self._model_state
    output_dir = self._output_dir

    pkl_module = utils.get_pickle_module()
    weights_file = os.path.join(output_dir, 'model.pkl')
    with tf.io.gfile.GFile(weights_file, 'wb') as f:
      pkl_module.dump((tuple(opt_state), step, history, model_state), f)
    if keep:
      weights_file = os.path.join(output_dir, 'model_{}.pkl'.format(step))
      with tf.io.gfile.GFile(weights_file, 'wb') as f:
        pkl_module.dump((tuple(opt_state), step, history, model_state), f)
    log('Model saved to %s' % weights_file, stdout=False)
Example #9
def load_trajectories(trajectory_dir, eval_frac):
  """Loads trajectories from a possibly nested directory of pickles."""
  pkl_module = utils.get_pickle_module()
  train_trajectories = []
  eval_trajectories = []
  # Search the entire directory subtree for trajectories.
  for (subdir, _, filenames) in tf.io.gfile.walk(trajectory_dir):
    for filename in filenames:
      shard_path = os.path.join(subdir, filename)
      try:
        with tf.io.gfile.GFile(shard_path, 'rb') as f:
          trajectories = pkl_module.load(f)
        pivot = int(len(trajectories) * (1 - eval_frac))
        train_trajectories.extend(trajectories[:pivot])
        eval_trajectories.extend(trajectories[pivot:])
      except EOFError:
        logging.warning(
            'Could not load trajectories from a corrupted shard %s.',
            shard_path,
        )
  assert train_trajectories, "Can't find training data in %s" % trajectory_dir
  assert eval_trajectories, "Can't find evaluation data in %s" % trajectory_dir
  return train_trajectories, eval_trajectories
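For each shard, the pivot index keeps the first (1 - eval_frac) fraction of trajectories for training and the remainder for evaluation. For example, with 20 trajectories in a shard and eval_frac = 0.05, pivot = int(20 * 0.95) = 19, so 19 trajectories go to the training set and 1 to the evaluation set.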
Example #10
  def _dump_trajectory_pickle(self, observations, path):
    pkl_module = utils.get_pickle_module()
    trajectories = list(map(self._make_singleton_trajectory, observations))
    with gfile.GFile(path, 'wb') as f:
      pkl_module.dump(trajectories, f)