Beispiel #1
0
  def init_from_file(self, file_name):
    """Restores this task's settings and recent trajectories from disk.

    The file at `file_name` is read with `training.unpickle_from_file`
    (gzip disabled) and is expected to hold a mapping with the keys
    'n_trajectories', 'n_interactions', 'max_steps', 'gamma' and
    'all_epochs'. Each selected epoch is then read from its own file, whose
    name comes from `self._epoch_filename(file_name, epoch)` (gzip enabled).

    Args:
      file_name: path of the top-level dump to read.
    """
    saved = training.unpickle_from_file(file_name, gzip=False)

    # Scalar settings are copied over verbatim.
    self._n_trajectories = saved['n_trajectories']
    self._n_interactions = saved['n_interactions']
    self._max_steps = saved['max_steps']
    self._gamma = saved['gamma']

    # Keep only the trailing `_n_replay_epochs` entries of the epoch list.
    # NOTE(review): if `_n_replay_epochs` were 0 this slice would be `[-0:]`,
    # i.e. the whole list -- confirm callers never configure 0.
    recent_epochs = saved['all_epochs'][-self._n_replay_epochs:]

    for epoch in recent_epochs:
      self._trajectories[epoch] = training.unpickle_from_file(
          self._epoch_filename(file_name, epoch), gzip=True)

    # Everything just loaded matches what is on disk.
    self._saved_epochs_unchanged = recent_epochs
Beispiel #2
0
  def load_initial_trajectories_from_path(self,
                                          initial_trajectories_path,
                                          dictionary_file='trajectories.pkl',
                                          start_epoch_to_load=0):
    """Loads and concatenates per-epoch trajectories from a dump directory.

    Args:
      initial_trajectories_path: directory that contains the dump.
      dictionary_file: name of the index file inside that directory; it is
        read without gzip and must contain an 'all_epochs' entry.
      start_epoch_to_load: position in 'all_epochs' from which loading
        starts; every epoch from that position to the end is loaded.

    Returns:
      A single list holding the contents of every loaded per-epoch file, in
      the order the epochs appear in 'all_epochs'.
    """
    # The dump is assumed to have been generated by Trax.
    index_path = os.path.join(initial_trajectories_path, dictionary_file)
    index = training.unpickle_from_file(index_path, gzip=False)

    # TODO(henrykm): the original note says access is limited to at most the
    # last n_replay_epochs and should be more flexible; as written here the
    # selection depends only on `start_epoch_to_load` -- confirm and update.
    selected_epochs = index['all_epochs'][start_epoch_to_load:]

    collected = []
    for epoch in selected_epochs:
      epoch_path = self._epoch_filename(index_path, epoch)
      collected.extend(training.unpickle_from_file(epoch_path, gzip=True))
    return collected
Beispiel #3
0
def load_trainer_state(output_dir, model, weights_file=None):
  """Loads a `TrainerState` from a gzipped pickle of trainer state.

  Args:
    output_dir: directory searched for 'model.pkl.gz' when `weights_file` is
      not given.
    model: passed through to `trainer_state_from_dict` together with the
      unpickled dictionary.
    weights_file: optional explicit path of the file to load.

  Returns:
    The `TrainerState` built from the file. If `weights_file` is None and the
    default file does not exist, a `TrainerState` with `step`, `opt_state`
    and `model_state` set to None and a fresh `History` is returned instead.

  Raises:
    ValueError: if an explicit `weights_file` is given but does not exist.
  """
  using_default_path = weights_file is None
  if using_default_path:
    weights_file = os.path.join(output_dir, 'model.pkl.gz')

  if not tf.io.gfile.exists(weights_file):
    # A missing default file just means there is nothing to resume from; a
    # missing explicitly requested file is treated as a caller error.
    if not using_default_path:
      raise ValueError('File not found: %s' % weights_file)
    return TrainerState(step=None, opt_state=None,
                        history=trax_history.History(), model_state=None)

  state_dict = training.unpickle_from_file(weights_file, gzip=True)
  loaded_state = trainer_state_from_dict(state_dict, model)
  log('Model loaded from %s at step %d' % (weights_file, loaded_state.step))
  logging.debug('From loaded model : history = %s', loaded_state.history)
  return loaded_state