Ejemplo n.º 1
0
def init_policy_from_world_model_checkpoint(policy_params, model_output_dir):
    """Copies world-model parameters from a checkpoint into the policy.

    Reads the pickle stored at `<model_output_dir>/model.pkl`, takes element
    `[0][0]` of the unpickled object as the world model parameters and writes
    the slice `[1:-2]` of them into `policy_params[0]`.

    Args:
      policy_params: Indexable, mutable container of policy parameters. Its
        first element is overwritten in place.
      model_output_dir: Directory containing the world model's `model.pkl`.

    Returns:
      The same `policy_params` object, after mutation.
    """
    pkl_module = utils.get_pickle_module()
    checkpoint_path = os.path.join(model_output_dir, "model.pkl")
    # trax.restore_state is deliberately not used here: importing it would
    # create a circular import.
    with gfile.GFile(checkpoint_path, "rb") as checkpoint:
        restored = pkl_module.load(checkpoint)
    world_model_params = restored[0][0]
    # TODO(pkozakowski): This slice is brittle. It is hardcoded for moving
    # parameters from a TransformerLM into a TransformerDecoder-based policy
    # network of the same configuration. Figure out a more general method.
    policy_params[0] = world_model_params[1:-2]
    return policy_params
Ejemplo n.º 2
0
def save_opt_state(output_dir, policy_and_value_opt_state,
                   policy_and_value_state, epoch, total_opt_step):
    """Writes a policy/value checkpoint for `epoch` and deletes older ones.

    The checkpoint is a pickled 3-tuple
    `(policy_and_value_opt_state, policy_and_value_state, total_opt_step)`
    stored as `<output_dir>/model-<epoch, zero-padded to 6 digits>.pkl`.

    Args:
      output_dir: Directory that holds the checkpoint files.
      policy_and_value_opt_state: Optimization state to store.
      policy_and_value_state: Network state to store.
      epoch: Epoch number used to build the checkpoint file name.
      total_opt_step: Total optimization step count to store.
    """
    pkl_module = utils.get_pickle_module()
    # The listing is taken before writing, so it only contains files that
    # existed prior to this call.
    previous_checkpoints = get_policy_model_files(output_dir)
    params_file = os.path.join(output_dir, "model-%06d.pkl" % epoch)
    checkpoint = (policy_and_value_opt_state, policy_and_value_state,
                  total_opt_step)
    with gfile.GFile(params_file, "wb") as sink:
        pkl_module.dump(checkpoint, sink)
    # Delete every previously listed file except the one just written (it is
    # in the listing when a checkpoint for this epoch already existed).
    for stale_path in (p for p in previous_checkpoints if p != params_file):
        gfile.remove(stale_path)
Ejemplo n.º 3
0
def save_state(state, output_dir, keep=False):
  """Pickles the training state into `output_dir`.

  The stored object is the 4-tuple
  `(tuple(state.opt_state), state.step, state.history, state.model_state)`.
  It is always written to `<output_dir>/model.pkl`; with `keep=True` the same
  content is additionally written to `<output_dir>/model_<step>.pkl`.

  Args:
    state: Object exposing `opt_state`, `step`, `history` and `model_state`.
    output_dir: Directory to write the checkpoint file(s) to.
    keep: Whether to also write a per-step copy of the checkpoint.
  """
  pkl_module = utils.get_pickle_module()

  def _write_checkpoint(path):
    # Serializes the current state to `path` and hands the path back.
    with gfile.GFile(path, "wb") as sink:
      contents = (tuple(state.opt_state), state.step, state.history,
                  state.model_state)
      pkl_module.dump(contents, sink)
    return path

  params_file = _write_checkpoint(os.path.join(output_dir, "model.pkl"))
  if keep:
    step_file_name = "model_{}.pkl".format(state.step)
    params_file = _write_checkpoint(os.path.join(output_dir, step_file_name))
  # Only the last file written is reported in the log message.
  log("Model saved to %s" % params_file, stdout=False)
Ejemplo n.º 4
0
def save_opt_state(output_dir, policy_and_value_opt_state,
                   policy_and_value_state, epoch, total_opt_step):
    """Writes a policy/value checkpoint and prunes older checkpoint files.

    The checkpoint is a pickled 3-tuple
    `(policy_and_value_opt_state, policy_and_value_state, total_opt_step)`
    stored as `<output_dir>/model-<epoch, zero-padded to 6 digits>.pkl`.
    Of the files that were listed before writing, the first
    `LAST_N_POLICY_MODELS_TO_KEEP` are left in place and the rest are removed.

    Args:
      output_dir: Directory that holds the checkpoint files.
      policy_and_value_opt_state: Optimization state to store.
      policy_and_value_state: Network state to store.
      epoch: Epoch number used to build the checkpoint file name.
      total_opt_step: Total optimization step count to store.
    """
    pkl_module = utils.get_pickle_module()
    listed_before_write = get_policy_model_files(output_dir)
    params_file = os.path.join(output_dir, "model-%06d.pkl" % epoch)
    with gfile.GFile(params_file, "wb") as sink:
        pkl_module.dump(
            (policy_and_value_opt_state, policy_and_value_state,
             total_opt_step),
            sink)
    # More than one file is retained because the most recent checkpoint may
    # still be in the middle of an asynchronous read.
    # NOTE(review): presumably get_policy_model_files lists newest first, so
    # the skipped prefix holds the latest checkpoints -- confirm.
    expired = listed_before_write[LAST_N_POLICY_MODELS_TO_KEEP:]
    for expired_path in expired:
        if expired_path == params_file:
            # Never delete the checkpoint that was just written.
            continue
        gfile.remove(expired_path)
Ejemplo n.º 5
0
def save_opt_state(output_dir, policy_and_value_opt_state,
                   policy_and_value_state, epoch, total_opt_step):
    """Writes a policy/value checkpoint and removes all but one older file.

    The checkpoint is a pickled 3-tuple
    `(policy_and_value_opt_state, policy_and_value_state, total_opt_step)`
    stored as `<output_dir>/model-<epoch, zero-padded to 6 digits>.pkl`.
    Of the files listed before writing, the first one is left untouched and
    every other one is removed.

    Args:
      output_dir: Directory that holds the checkpoint files.
      policy_and_value_opt_state: Optimization state to store.
      policy_and_value_state: Network state to store.
      epoch: Epoch number used to build the checkpoint file name.
      total_opt_step: Total optimization step count to store.
    """
    pkl_module = utils.get_pickle_module()
    existing_checkpoints = get_policy_model_files(output_dir)
    params_file = os.path.join(output_dir, "model-%06d.pkl" % epoch)
    snapshot = (policy_and_value_opt_state, policy_and_value_state,
                total_opt_step)
    with gfile.GFile(params_file, "wb") as sink:
        pkl_module.dump(snapshot, sink)
    # The first listed file is spared because it might still be getting read
    # asynchronously; a later call cleans it up.
    # NOTE(review): presumably the listing is ordered newest first -- confirm
    # against get_policy_model_files.
    removable = existing_checkpoints[1:]
    for candidate in removable:
        if candidate == params_file:
            # Never delete the checkpoint that was just written.
            continue
        gfile.remove(candidate)
Ejemplo n.º 6
0
def restore_state(output_dir):
  """Loads the training state stored in `<output_dir>/model.pkl`.

  Args:
    output_dir: Directory expected to contain `model.pkl`.

  Returns:
    A `State`. When the checkpoint file does not exist, `step`, `opt_state`
    and `model_state` are None and `history` is a fresh
    `trax_history.History()`. Otherwise the fields come from the pickled
    4-tuple `(opt_state, step, history, model_state)`, with `opt_state`
    expanded into an `OptState`.
  """
  params_file = os.path.join(output_dir, "model.pkl")
  checkpoint_present = gfile.exists(params_file)
  if not checkpoint_present:
    # Nothing to restore: hand back an empty state.
    return State(opt_state=None, step=None, model_state=None,
                 history=trax_history.History())

  pkl_module = utils.get_pickle_module()
  with gfile.GFile(params_file, "rb") as source:
    restored = pkl_module.load(source)
  opt_state, step, history, model_state = restored
  log("Model loaded from %s at step %d" % (params_file, step))
  logging.debug("From loaded model : history = %s", history)
  return State(opt_state=OptState(*opt_state), step=step,
               model_state=model_state, history=history)
Ejemplo n.º 7
0
def load_trajectories(trajectory_dir, eval_frac):
    """Reads pickled trajectory shards and splits them into train and eval.

    Every file found anywhere under `trajectory_dir` is unpickled. Within each
    shard, the first `int(len(shard) * (1 - eval_frac))` trajectories go to
    the training set and the remaining ones to the evaluation set.

    Args:
      trajectory_dir: Root directory; its whole subtree is searched.
      eval_frac: Fraction of each shard assigned to evaluation.

    Returns:
      Tuple `(train_trajectories, eval_trajectories)` of lists.

    Raises:
      AssertionError: If either resulting list is empty.
    """
    pkl_module = utils.get_pickle_module()
    train_split, eval_split = [], []
    train_share = 1 - eval_frac
    # Visit every file in the directory subtree; each one is taken to be a
    # shard of trajectories.
    for (parent, _, names) in gfile.walk(trajectory_dir):
        for name in names:
            with gfile.GFile(os.path.join(parent, name), "rb") as shard:
                loaded = pkl_module.load(shard)
            boundary = int(len(loaded) * train_share)
            train_split.extend(loaded[:boundary])
            eval_split.extend(loaded[boundary:])
    assert train_split, "Haven't found any training data."
    assert eval_split, "Haven't found any evaluation data."
    return (train_split, eval_split)
Ejemplo n.º 8
0
    def dump_trajectories(self, force=False):
        """Writes buffered trajectories to the shard file of the current epoch.

        Completed trajectories of the training environment that contain at
        least one action are moved into an internal buffer. The buffer is
        pickled to `<trajectory_dump_dir>/<epoch>.pkl` once it holds at least
        `_trajectory_dump_min_count_per_shard` items, or unconditionally when
        `force` is set. If that shard file already exists, its content is
        loaded and placed in front of the buffer before writing. Nothing
        happens when `trajectory_dump_dir` is None.

        Intended to be called at most once per epoch.

        Args:
          force: (bool) Whether to complete unfinished trajectories and create
            a new shard even if we have not reached the minimum size.
        """
        pkl_module = utils.get_pickle_module()
        dump_dir = self.trajectory_dump_dir
        if dump_dir is None:
            return
        gfile.makedirs(dump_dir)

        env_trajectories = self.train_env.trajectories
        if force:
            env_trajectories.complete_all_trajectories()

        # complete_all_trajectories() also yields trajectories that were just
        # reset. Those hold only the initial observation and no action, so
        # they are dropped here.
        self._trajectory_buffer.extend(
            candidate
            for candidate in env_trajectories.completed_trajectories
            if candidate.time_steps
            and candidate.time_steps[0].action is not None)
        env_trajectories.clear_completed_trajectories()

        buffered_enough = (len(self._trajectory_buffer) >=
                           self._trajectory_dump_min_count_per_shard)
        if not (force or buffered_enough):
            return

        shard_path = os.path.join(dump_dir, "{}.pkl".format(self.epoch))
        if gfile.exists(shard_path):
            # An extra dump happens at the end of the training loop, so the
            # same epoch is sometimes dumped twice. Merge with what is already
            # stored instead of overwriting it.
            with gfile.GFile(shard_path, "rb") as f:
                previously_dumped = pkl_module.load(f)
            self._trajectory_buffer = (
                previously_dumped + self._trajectory_buffer)
        with gfile.GFile(shard_path, "wb") as f:
            pkl_module.dump(self._trajectory_buffer, f)
        self._trajectory_buffer = []
Ejemplo n.º 9
0
def maybe_restore_opt_state(output_dir,
                            policy_and_value_opt_state=None,
                            policy_and_value_state=None):
    """Restores the optimization state from a checkpoint, if one is loadable.

    The files returned by `get_policy_model_files(output_dir)` are tried in
    order. The first one that unpickles without an `EOFError` supplies the
    result; files that raise `EOFError` are logged and skipped.

    Args:
      output_dir: Directory where saved model checkpoints are stored.
      policy_and_value_opt_state: Default optimization state, returned if no
        checkpoint could be loaded.
      policy_and_value_state: Default state of the policy and value network,
        returned if no checkpoint could be loaded.

    Returns:
      Tuple `(opt_state, state, epoch, opt_step)`. `epoch` is the epoch parsed
      from the restored file's name and `opt_step` the total optimization step
      stored in it; both are 0 when nothing was restored.
    """
    pkl_module = utils.get_pickle_module()
    restored_epoch = 0
    restored_opt_step = 0
    for candidate in get_policy_model_files(output_dir):
        logging.info("Trying to restore model from %s", candidate)
        try:
            with gfile.GFile(candidate, "rb") as checkpoint:
                loaded = pkl_module.load(checkpoint)
            (policy_and_value_opt_state, policy_and_value_state,
             restored_opt_step) = loaded
            restored_epoch = get_epoch_from_policy_model_file(candidate)
        except EOFError as err:
            # This file could not be read to the end; move on to the next
            # candidate in the listing.
            logging.error("Unable to load model from: %s with %s", candidate,
                          err)
        else:
            # Successfully restored, no need to look at further files.
            break
    return (
        policy_and_value_opt_state,
        policy_and_value_state,
        restored_epoch,
        restored_opt_step,
    )
Ejemplo n.º 10
0
def load_trajectories(trajectory_dir, eval_frac):
  """Reads pickled trajectory shards and splits them into train and eval.

  Every file found anywhere under `trajectory_dir` is unpickled. Within each
  shard, the first `int(len(shard) * (1 - eval_frac))` trajectories go to the
  training set and the remaining ones to the evaluation set. A shard that
  raises `EOFError` while being processed is logged and skipped.

  Args:
    trajectory_dir: Root directory; its whole subtree is searched.
    eval_frac: Fraction of each shard assigned to evaluation.

  Returns:
    Tuple `(train_trajectories, eval_trajectories)` of lists.

  Raises:
    AssertionError: If either resulting list is empty.
  """
  pkl_module = utils.get_pickle_module()
  train_split = []
  eval_split = []
  # Lazily enumerate every file in the directory subtree.
  shard_paths = (
      os.path.join(parent, name)
      for (parent, _, names) in gfile.walk(trajectory_dir)
      for name in names
  )
  for shard_path in shard_paths:
    try:
      with gfile.GFile(shard_path, "rb") as shard:
        loaded = pkl_module.load(shard)
      boundary = int(len(loaded) * (1 - eval_frac))
      train_split.extend(loaded[:boundary])
      eval_split.extend(loaded[boundary:])
    except EOFError:
      logging.warning(
          "Could not load trajectories from a corrupted shard %s.",
          shard_path,
      )
  assert train_split, "Can't find training data in %s" % trajectory_dir
  assert eval_split, "Can't find evaluation data in %s" % trajectory_dir
  return train_split, eval_split
Ejemplo n.º 11
0
 def _dump_trajectory_pickle(self, observations, path):
     """Pickles one singleton trajectory per observation to `path`.

     Each item of `observations` is passed through
     `self._make_singleton_trajectory`, and the resulting list is written to
     `path` as a single pickle.
     """
     pkl_module = utils.get_pickle_module()
     trajectories = [
         self._make_singleton_trajectory(observation)
         for observation in observations
     ]
     with gfile.GFile(path, "wb") as sink:
         pkl_module.dump(trajectories, sink)