예제 #1
0
파일: task.py 프로젝트: rouniuyizu/trax
 def init_from_file(self, file_name):
   """Restores this task's state from an index file and its epoch files.

   The index at `file_name` is read as an uncompressed pickle. Its
   'max_steps' and 'gamma' entries are copied onto the task. Then every
   epoch listed under 'all_epochs' is read from the gzip-flagged pickle
   at `self._epoch_filename(file_name, epoch)` and stored in
   `self._trajectories[epoch]`.

   Args:
     file_name: Path of the index file.
   """
   saved = trainer_lib.unpickle_from_file(file_name, gzip=False)
   # Assignments happen in this order, so a missing key leaves the
   # attributes set before it already updated.
   self._max_steps = saved['max_steps']
   self._gamma = saved['gamma']
   saved_epochs = saved['all_epochs']
   for epoch_id in saved_epochs:
     epoch_path = self._epoch_filename(file_name, epoch_id)
     self._trajectories[epoch_id] = trainer_lib.unpickle_from_file(
         epoch_path, gzip=True)
   # NOTE(review): presumably lets a later save skip rewriting epochs that
   # were loaded and not modified since; confirm against the save path.
   self._saved_epochs_unchanged = saved_epochs
예제 #2
0
    def load_initial_trajectories_from_path(self,
                                            initial_trajectories_path,
                                            dictionary_file='trajectories.pkl',
                                            start_epoch_to_load=0):
        """Loads and concatenates trajectories from a dump directory.

        The method reads the dump but does not assign any attributes of this
        task. The caller decides what to do with the returned trajectories.

        The files are unpickled, so only load dumps from trusted sources.

        Args:
          initial_trajectories_path: Directory holding the dump. It is assumed
            to have been written by Trax.
          dictionary_file: Name of the index file inside
            `initial_trajectories_path`. The index is read as an uncompressed
            pickle, and its 'all_epochs' entry lists the saved epochs.
          start_epoch_to_load: Slice start into the 'all_epochs' list. Epochs
            before this position are skipped. Negative values count from the
            end, as with ordinary Python slicing.

        Returns:
          A list with the trajectories of every loaded epoch, concatenated in
          the order the epochs appear in 'all_epochs'.
        """
        # We assume that this is a dump generated by Trax
        dictionary_path = os.path.join(initial_trajectories_path,
                                       dictionary_file)
        dictionary = trainer_lib.unpickle_from_file(dictionary_path,
                                                    gzip=False)
        # TODO(henrykm): as currently implemented this accesses only
        # at most the last n_replay_epochs - this should be more flexible
        epochs_to_load = dictionary['all_epochs'][start_epoch_to_load:]

        all_trajectories = []
        for epoch in epochs_to_load:
            # Per-epoch file names are derived from the index file's path.
            # Each file is assumed to unpickle to a list (or other iterable)
            # of trajectories.
            epoch_path = self._epoch_filename(dictionary_path, epoch)
            all_trajectories.extend(
                trainer_lib.unpickle_from_file(epoch_path, gzip=True))
        return all_trajectories