def init_from_file(self, file_name):
  """Restore this task's state from a dump previously written to disk.

  Reads the uncompressed pickled metadata dictionary stored at `file_name`
  and restores `max_steps` and `gamma` from it. It then loads the gzip-pickled
  trajectories of every epoch listed under the dictionary's 'all_epochs'
  entry into `self._trajectories`, keyed by epoch. Existing entries of
  `self._trajectories` for epochs not listed in the dump are left untouched.

  Args:
    file_name: path of the metadata pickle. The per-epoch trajectory files
      are located via `self._epoch_filename(file_name, epoch)`.
  """
  metadata = trainer_lib.unpickle_from_file(file_name, gzip=False)
  self._max_steps = metadata['max_steps']
  self._gamma = metadata['gamma']
  saved_epochs = metadata['all_epochs']
  # Load one epoch file at a time, storing each as soon as it is read.
  for ep in saved_epochs:
    epoch_path = self._epoch_filename(file_name, ep)
    self._trajectories[ep] = trainer_lib.unpickle_from_file(
        epoch_path, gzip=True)
  # Record exactly the epochs that came from disk. Presumably this lets a
  # later save skip rewriting them -- NOTE(review): confirm against the
  # save path.
  self._saved_epochs_unchanged = saved_epochs
def load_initial_trajectories_from_path(self,
                                        initial_trajectories_path,
                                        dictionary_file='trajectories.pkl',
                                        start_epoch_to_load=0):
  """Read previously saved trajectories from a directory and return them.

  This method does not assign any task attributes itself. It only reads the
  dump and returns its contents; the caller decides what to do with them.

  Args:
    initial_trajectories_path: directory holding the saved dump. It is
      assumed to have been written by Trax (TODO confirm the expected layout).
    dictionary_file: file name, relative to `initial_trajectories_path`, of
      the uncompressed pickled metadata dictionary. Its 'all_epochs' entry
      lists the saved epochs.
    start_epoch_to_load: position in the saved 'all_epochs' list (NOT an
      epoch number) of the first epoch to load. Every epoch listed after it
      is loaded too.

  Returns:
    One list concatenating the per-epoch trajectory collections, in the
    order the epochs appear in 'all_epochs'. Each per-epoch file is presumed
    to hold a list of trajectories -- TODO confirm.
  """
  import os  # Local import keeps this block self-contained.

  # Keep the caller's relative name separate from the resolved path.
  dictionary_path = os.path.join(initial_trajectories_path, dictionary_file)
  # NOTE(review): pickles can execute arbitrary code when loaded; only point
  # this at dumps from a trusted source.
  dictionary = trainer_lib.unpickle_from_file(dictionary_path, gzip=False)
  # `start_epoch_to_load` selects a suffix of the saved epoch list. There is
  # currently no way to cap the count or choose an arbitrary subset.
  epochs_to_load = dictionary['all_epochs'][start_epoch_to_load:]
  all_trajectories = []
  for epoch in epochs_to_load:
    epoch_trajectories = trainer_lib.unpickle_from_file(
        self._epoch_filename(dictionary_path, epoch), gzip=True)
    all_trajectories += epoch_trajectories
  return all_trajectories