def init_from_file(self, file_name):
    """Initialize this task from file.

    Reads the metadata dictionary stored in `file_name`, restores
    `_n_trajectories`, `_n_interactions`, `_max_steps` and `_gamma` from it,
    and loads the trajectories of the last `self._n_replay_epochs` epochs
    listed under its 'all_epochs' key into `self._trajectories`.

    Args:
      file_name: Path of the metadata file (read with `gzip=False`). The
        per-epoch trajectory files (read with `gzip=True`) are located via
        `self._epoch_filename(file_name, epoch)`.

    A non-positive `self._n_replay_epochs` loads no epochs.
    """
    dictionary = training.unpickle_from_file(file_name, gzip=False)
    self._n_trajectories = dictionary['n_trajectories']
    self._n_interactions = dictionary['n_interactions']
    self._max_steps = dictionary['max_steps']
    self._gamma = dictionary['gamma']
    all_epochs = dictionary['all_epochs']
    n_replay = self._n_replay_epochs
    # `all_epochs[-0:]` is the WHOLE list (because -0 == 0), so a bare
    # `[-n_replay:]` would load every epoch when n_replay == 0, and a negative
    # count would silently drop a prefix instead. Treat non-positive counts as
    # "replay nothing"; `[:0]` yields an empty slice of the same sequence type.
    if n_replay > 0:
        epochs_to_load = all_epochs[-n_replay:]
    else:
        epochs_to_load = all_epochs[:0]
    for epoch in epochs_to_load:
        trajectories = training.unpickle_from_file(
            self._epoch_filename(file_name, epoch), gzip=True)
        self._trajectories[epoch] = trajectories
    # Remember which epochs were just read from disk (presumably used to skip
    # re-saving epochs whose files are still current — confirm with the saver).
    self._saved_epochs_unchanged = epochs_to_load
def load_initial_trajectories_from_path(self,
                                        initial_trajectories_path,
                                        dictionary_file='trajectories.pkl',
                                        start_epoch_to_load=0,
                                        end_epoch_to_load=None):
    """Initialize trajectories task from file.

    Loads the trajectories of a range of the epochs listed in a dump directory
    and returns them concatenated into one flat list.

    Args:
      initial_trajectories_path: Directory containing the dump. We assume the
        dump was generated by Trax.
      dictionary_file: Name of the metadata file inside
        `initial_trajectories_path` (read with `gzip=False`); its
        'all_epochs' entry lists the available epochs.
      start_epoch_to_load: Position in the 'all_epochs' list (an index, not
        an epoch number) of the first epoch to load. Ordinary Python slice
        semantics apply, so a negative value counts from the end.
      end_epoch_to_load: Position in the 'all_epochs' list one past the last
        epoch to load, with slice semantics; None (the default) loads through
        the final listed epoch.

    Returns:
      A list with the trajectories of every selected epoch, in the order the
      epochs appear in 'all_epochs'.
    """
    # We assume that this is a dump generated by Trax.
    dictionary_path = os.path.join(initial_trajectories_path, dictionary_file)
    dictionary = training.unpickle_from_file(dictionary_path, gzip=False)
    # TODO(henrykm): only a contiguous range of 'all_epochs' can be selected;
    # consider supporting an arbitrary subset of epochs.
    epochs_to_load = dictionary['all_epochs'][start_epoch_to_load:
                                              end_epoch_to_load]
    all_trajectories = []
    for epoch in epochs_to_load:
        # Per-epoch file names are derived from the metadata file's path.
        all_trajectories.extend(training.unpickle_from_file(
            self._epoch_filename(dictionary_path, epoch), gzip=True))
    return all_trajectories
def load_trainer_state(output_dir, model, weights_file=None):
    """Load a `TrainerState` checkpoint.

    Args:
      output_dir: Directory holding the default checkpoint 'model.pkl.gz';
        only consulted when `weights_file` is None.
      model: Forwarded to `trainer_state_from_dict`.
      weights_file: Optional explicit checkpoint path.

    Returns:
      The `TrainerState` rebuilt from the checkpoint. When no `weights_file`
      was given and the default checkpoint does not exist, a fresh state is
      returned instead: `step`, `opt_state` and `model_state` are None and
      `history` is a new `trax_history.History()`.

    Raises:
      ValueError: if an explicitly given `weights_file` does not exist.
    """
    explicit_path = weights_file is not None
    if explicit_path:
        path = weights_file
    else:
        path = os.path.join(output_dir, 'model.pkl.gz')

    if not tf.io.gfile.exists(path):
        if explicit_path:
            # The caller asked for a specific file: its absence is an error.
            raise ValueError('File not found: %s' % path)
        # Nothing saved in output_dir yet: start from an empty state.
        return TrainerState(step=None, opt_state=None,
                            history=trax_history.History(), model_state=None)

    state_dict = training.unpickle_from_file(path, gzip=True)
    trainer_state = trainer_state_from_dict(state_dict, model)
    log('Model loaded from %s at step %d' % (path, trainer_state.step))
    logging.debug('From loaded model : history = %s', trainer_state.history)
    return trainer_state