Example #1
0
    def test_unpack(self):
        """Check that ``CheckpointsManager.unpack`` restores what ``pack`` archived.

        Writes distinct contents ('1', '2', '3') to the weights, optimizer-state
        and trainer files, packs them, unpacks them again, then verifies that each
        file exists and still holds its original content.
        """
        fsm = FileStructManager(base_dir=self.base_dir, is_continue=False)
        cm = CheckpointsManager(fsm)

        for content, path in enumerate(
                [cm.weights_file(),
                 cm.optimizer_state_file(),
                 cm.trainer_file()], start=1):
            # ``with`` guarantees the data is flushed and the handle closed
            # before ``pack`` reads the file.
            with open(path, 'w') as file:
                file.write(str(content))

        cm.pack()

        try:
            cm.unpack()
        except Exception as err:
            self.fail('Exception on unpacking: {}'.format(err))

        for content, path in enumerate(
                [cm.weights_file(),
                 cm.optimizer_state_file(),
                 cm.trainer_file()], start=1):
            # ``isfile`` is False for missing paths, so it also covers existence.
            if not os.path.isfile(path):
                self.fail("File '{}' doesn't exist after unpack".format(path))
            with open(path, 'r') as file:
                if file.read() != str(content):
                    self.fail("File content corrupted")
 def __init__(self,
              model: Model,
              fsm: FileStructManager,
              device: torch.device = None):
     """Wrap ``model`` in a ``DataProcessor`` and load its state from the checkpoint.

     Args:
         model: model handed to the ``DataProcessor``.
         fsm: file structure manager used to build the ``CheckpointsManager``.
         device: forwarded to ``DataProcessor`` as ``device``; ``None`` by default.
     """
     self._fsm = fsm
     self.__data_processor = DataProcessor(model, device=device)
     checkpoint_manager = CheckpointsManager(self._fsm)
     self.__data_processor.set_checkpoints_manager(checkpoint_manager)
     checkpoint_manager.unpack()
     # Re-pack even if loading fails, so the unpacked checkpoint files are not
     # left behind on disk.
     try:
         self.__data_processor.load()
     finally:
         checkpoint_manager.pack()
Example #3
0
 def __init__(self,
              model: Model,
              fsm: FileStructManager,
              from_best_state: bool = False):
     """Wrap ``model`` in a ``DataProcessor`` and load its state from a checkpoint.

     Args:
         model: model handed to the ``DataProcessor``.
         fsm: file structure manager used to build the ``CheckpointsManager``.
         from_best_state: if True, read the checkpoint saved with the ``'best'``
             prefix; otherwise use the default (``None``) prefix.
     """
     self._fsm = fsm
     self._data_processor = DataProcessor(model)
     checkpoint_manager = CheckpointsManager(
         self._fsm, prefix='best' if from_best_state else None)
     self._data_processor.set_checkpoints_manager(checkpoint_manager)
     checkpoint_manager.unpack()
     # Re-pack even if loading fails, so the unpacked checkpoint files are not
     # left behind on disk.
     try:
         self._data_processor.load()
     finally:
         checkpoint_manager.pack()
Example #4
0
    def _resume(self) -> int:
        """Restore training state from a checkpoint and return the epoch to start at.

        ``self._resume_from`` selects the checkpoint:

        - ``'last'`` reuses ``self._checkpoint_manager``.
        - ``'best'`` builds a new ``CheckpointsManager`` with ``'best'`` as its
          second argument.

        Returns:
            The value of ``'last_epoch'`` stored in the checkpoint's trainer file,
            plus one.

        Raises:
            NotImplementedError: if ``self._resume_from`` is neither ``'last'``
                nor ``'best'``.
        """
        # Dispatch on the requested resume mode (the original compared the
        # manager object itself to 'best', which made 'best' unreachable).
        if self._resume_from == 'last':
            ckpts_manager = self._checkpoint_manager
        elif self._resume_from == 'best':
            ckpts_manager = CheckpointsManager(self._fsm, 'best')
        else:
            raise NotImplementedError("Resume parameter may be only 'last' or 'best' not {}".format(self._resume_from))

        ckpts_manager.unpack()
        # Re-pack even if loading or reading the trainer file fails, so the
        # unpacked checkpoint files are not left behind on disk.
        try:
            self._data_processor.load()
            with open(ckpts_manager.trainer_file(), 'r') as file:
                start_epoch_idx = json.load(file)['last_epoch'] + 1
        finally:
            ckpts_manager.pack()
        return start_epoch_idx