def test_continue_from_checkpoint(self):
    def on_node(n1, n2):
        self.assertTrue(np.array_equal(n1.numpy(), n2.numpy()))

    model = SimpleModel().train()
    dp = DataProcessor(model=model)

    with self.assertRaises(Model.ModelException):
        dp.save_state()
    try:
        fsm = FileStructManager(self.base_dir, is_continue=False)
        dp.set_checkpoints_manager(CheckpointsManager(fsm))
        dp.save_state()
    except:
        self.fail("Fail to save DataProcessor state when CheckpointsManager was defined")

    del dp

    model_new = SimpleModel().train()
    dp = DataProcessor(model=model_new)

    with self.assertRaises(Model.ModelException):
        dp.load()
    try:
        fsm = FileStructManager(self.base_dir, is_continue=True)
        dp.set_checkpoints_manager(CheckpointsManager(fsm))
        dp.load()
    except:
        self.fail("Fail to load DataProcessor state when CheckpointsManager was defined")

    compare_two_models(self, model, model_new)
def _save_state(self, ckpts_manager: CheckpointsManager, best_ckpts_manager: CheckpointsManager or None,
                cur_best_state: float or None, epoch_idx: int) -> float or None:
    """
    Internal method used to save states after an epoch ends

    :param ckpts_manager: ordinal checkpoints manager
    :param best_ckpts_manager: checkpoints manager used to store the best states
    :param cur_best_state: current best state metric value
    :param epoch_idx: index of the epoch that just finished
    :return: new best state metric value or None if it wasn't updated
    """
    def save_trainer(ckp_manager):
        with open(ckp_manager.trainer_file(), 'w') as out:
            json.dump({'last_epoch': epoch_idx}, out)

    if self._best_state_rule is not None:
        new_best_state = self._best_state_rule()
        if cur_best_state is None:
            self._data_processor.save_state()
            save_trainer(ckpts_manager)
            ckpts_manager.pack()
            return new_best_state
        else:
            if new_best_state <= cur_best_state:
                self._data_processor.set_checkpoints_manager(best_ckpts_manager)
                self._data_processor.save_state()
                save_trainer(best_ckpts_manager)
                best_ckpts_manager.pack()
                self._data_processor.set_checkpoints_manager(ckpts_manager)
                return new_best_state

    self._data_processor.save_state()
    save_trainer(ckpts_manager)
    ckpts_manager.pack()
    return None
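# A minimal, standalone sketch of the decision rule encoded in ``_save_state`` above (not
# part of the library): a best-state checkpoint is written only when the value returned by
# ``_best_state_rule()`` is less than or equal to the current best, i.e. lower metric values
# are treated as better. The helper name below is illustrative only.
def _should_save_best(new_best_state: float, cur_best_state: float or None) -> bool:
    # First evaluation: nothing to compare against, so only the ordinal checkpoint is written
    if cur_best_state is None:
        return False
    # Later evaluations: write the best-state checkpoint when the metric did not get worse
    return new_best_state <= cur_best_state


assert not _should_save_best(0.5, None)   # first epoch: no best value tracked yet
assert _should_save_best(0.4, 0.5)        # improvement -> best checkpoint is packed
assert not _should_save_best(0.6, 0.5)    # regression -> only the ordinal checkpoint is packed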
def test_initialisation(self):
    fsm = FileStructManager(base_dir=self.base_dir, is_continue=False)
    try:
        cm = CheckpointsManager(fsm)
    except Exception as err:
        self.fail("Fail to init CheckpointsManager; err: ['{}']".format(err))

    with self.assertRaises(FileStructManager.FSMException):
        CheckpointsManager(fsm)

    os.mkdir(os.path.join(fsm.get_path(cm), 'test_dir'))
    with self.assertRaises(FileStructManager.FSMException):
        CheckpointsManager(fsm)
def test_continue_from_checkpoint(self):
    def on_node(n1, n2):
        self.assertTrue(np.array_equal(n1.numpy(), n2.numpy()))

    model = SimpleModel().train()
    loss = SimpleLoss()

    for optim in [torch.optim.SGD(model.parameters(), lr=0.1), torch.optim.Adam(model.parameters(), lr=0.1)]:
        train_config = TrainConfig([], loss, optim)

        dp_before = TrainDataProcessor(model=model, train_config=train_config)
        before_state_dict = model.state_dict().copy()
        dp_before.update_lr(0.023)

        with self.assertRaises(Model.ModelException):
            dp_before.save_state()
        try:
            fsm = FileStructManager(base_dir=self.base_dir, is_continue=False)
            dp_before.set_checkpoints_manager(CheckpointsManager(fsm))
            dp_before.save_state()
        except:
            self.fail("Exception on saving state when 'CheckpointsManager' specified")

        fsm = FileStructManager(base_dir=self.base_dir, is_continue=True)
        dp_after = TrainDataProcessor(model=model, train_config=train_config)
        with self.assertRaises(Model.ModelException):
            dp_after.load()
        try:
            cm = CheckpointsManager(fsm)
            dp_after.set_checkpoints_manager(cm)
            dp_after.load()
        except:
            self.fail("Exception on loading state when 'CheckpointsManager' specified")

        after_state_dict = model.state_dict().copy()

        dict_pair_recursive_bypass(before_state_dict, after_state_dict, on_node)
        self.assertEqual(dp_before.get_lr(), dp_after.get_lr())

        shutil.rmtree(self.base_dir)
def test_clear_files(self):
    fsm = FileStructManager(base_dir=self.base_dir, is_continue=False)
    cm = CheckpointsManager(fsm)

    f = open(cm.weights_file(), 'w')
    f.close()
    f = open(cm.optimizer_state_file(), 'w')
    f.close()

    cm.clear_files()

    for f in [cm.weights_file(), cm.optimizer_state_file()]:
        if os.path.exists(f) and os.path.isfile(f):
            self.fail("File '{}' wasn't removed by clear_files".format(f))
def __init__(self, train_config: TrainConfig, fsm: FileStructManager, device: torch.device = None):
    self._fsm = fsm
    self.monitor_hub = MonitorHub()

    self._checkpoint_manager = CheckpointsManager(self._fsm)

    self.__epoch_num = 100
    self._resume_from = None
    self._on_epoch_end = []
    self._best_state_rule = None

    self._train_config = train_config
    self._data_processor = TrainDataProcessor(self._train_config, device) \
        .set_checkpoints_manager(self._checkpoint_manager)
    self._lr = LearningRate(self._data_processor.get_lr())
    self._stop_rules = []
def __init__(self, model: Model, fsm: FileStructManager, device: torch.device = None):
    self._fsm = fsm
    self.__data_processor = DataProcessor(model, device=device)
    checkpoint_manager = CheckpointsManager(self._fsm)
    self.__data_processor.set_checkpoints_manager(checkpoint_manager)
    checkpoint_manager.unpack()
    self.__data_processor.load()
    checkpoint_manager.pack()
def _resume(self) -> int:
    if self._resume_from == 'last':
        ckpts_manager = self._checkpoint_manager
    elif self._resume_from == 'best':
        ckpts_manager = CheckpointsManager(self._fsm, 'best')
    else:
        raise NotImplementedError("Resume parameter may be only 'last' or 'best', not '{}'".format(self._resume_from))

    ckpts_manager.unpack()
    self._data_processor.load()

    with open(ckpts_manager.trainer_file(), 'r') as file:
        start_epoch_idx = json.load(file)['last_epoch'] + 1

    ckpts_manager.pack()
    return start_epoch_idx
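# A self-contained sketch (not library code) of the trainer-file round trip that ``_resume``
# relies on: ``save_trainer`` in ``_save_state`` writes ``{'last_epoch': <idx>}`` as JSON, and
# ``_resume`` reads it back and continues from the following epoch. Paths here are temporary
# stand-ins for what ``CheckpointsManager.trainer_file()`` would return.
import json
import os
import tempfile

tmp_dir = tempfile.mkdtemp()
trainer_file = os.path.join(tmp_dir, 'trainer.json')

# What save_trainer() writes after epoch 4 finishes
with open(trainer_file, 'w') as out:
    json.dump({'last_epoch': 4}, out)

# What _resume() reads back: the run continues from epoch 5
with open(trainer_file, 'r') as file:
    start_epoch_idx = json.load(file)['last_epoch'] + 1
assert start_epoch_idx == 5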
def __init__(self, model: Model, fsm: FileStructManager, from_best_state: bool = False):
    self._fsm = fsm
    self._data_processor = DataProcessor(model)
    checkpoint_manager = CheckpointsManager(self._fsm, prefix='best' if from_best_state else None)
    self._data_processor.set_checkpoints_manager(checkpoint_manager)
    checkpoint_manager.unpack()
    self._data_processor.load()
    checkpoint_manager.pack()
def train(self) -> None:
    """
    Run the training process
    """
    if len(self._train_config.stages()) < 1:
        raise self.TrainerException("There are no stages for training")

    best_checkpoints_manager = None
    cur_best_state = None
    if self._best_state_rule is not None:
        best_checkpoints_manager = CheckpointsManager(self._fsm, 'best')

    start_epoch_idx = 1
    if self._resume_from is not None:
        start_epoch_idx += self._resume()

    self.monitor_hub.add_monitor(ConsoleMonitor())

    with self.monitor_hub:
        for epoch_idx in range(start_epoch_idx, self.__epoch_num + start_epoch_idx):
            self.monitor_hub.set_epoch_num(epoch_idx)
            for stage in self._train_config.stages():
                stage.run(self._data_processor)

                if stage.metrics_processor() is not None:
                    self.monitor_hub.update_metrics(stage.metrics_processor().get_metrics())

            new_best_state = self._save_state(self._checkpoint_manager, best_checkpoints_manager,
                                              cur_best_state, epoch_idx)
            if new_best_state is not None:
                cur_best_state = new_best_state

            self._data_processor.update_lr(self._lr.value())

            for clbk in self._on_epoch_end:
                clbk()

            self._update_losses()
            self.__iterate_by_stages(lambda s: s.on_epoch_end())
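# A tiny illustration (not library code) of the epoch window used by ``train`` above:
# ``range(start_epoch_idx, epoch_num + start_epoch_idx)`` always yields exactly ``epoch_num``
# epochs, whatever start index a resumed run computes.
epoch_num = 100

fresh_run = list(range(1, epoch_num + 1))
assert fresh_run[0] == 1 and fresh_run[-1] == 100 and len(fresh_run) == epoch_num

resumed_start = 12  # example start index after a resume
resumed_run = list(range(resumed_start, epoch_num + resumed_start))
assert resumed_run[0] == 12 and len(resumed_run) == epoch_num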
def test_unpack(self):
    fsm = FileStructManager(base_dir=self.base_dir, is_continue=False)
    cm = CheckpointsManager(fsm)

    f = open(cm.weights_file(), 'w')
    f.write('1')
    f.close()
    f = open(cm.optimizer_state_file(), 'w')
    f.write('2')
    f.close()
    f = open(cm.trainer_file(), 'w')
    f.write('3')
    f.close()

    cm.pack()

    try:
        cm.unpack()
    except Exception as err:
        self.fail('Exception on unpacking: [{}]'.format(err))

    for i, f in enumerate([cm.weights_file(), cm.optimizer_state_file(), cm.trainer_file()]):
        if not (os.path.exists(f) and os.path.isfile(f)):
            self.fail("File '{}' doesn't exist after unpack".format(f))
        with open(f, 'r') as file:
            if file.read() != str(i + 1):
                self.fail("File content corrupted")
def test_pack(self):
    fsm = FileStructManager(base_dir=self.base_dir, is_continue=False)
    cm = CheckpointsManager(fsm)

    with self.assertRaises(CheckpointsManager.SMException):
        cm.pack()

    os.mkdir(cm.weights_file())
    os.mkdir(cm.optimizer_state_file())
    with self.assertRaises(CheckpointsManager.SMException):
        cm.pack()

    shutil.rmtree(cm.weights_file())
    shutil.rmtree(cm.optimizer_state_file())

    f = open(cm.weights_file(), 'w')
    f.close()
    f = open(cm.optimizer_state_file(), 'w')
    f.close()
    f = open(cm.trainer_file(), 'w')
    f.close()

    try:
        cm.pack()
    except Exception as err:
        self.fail('Exception on packing files: [{}]'.format(err))

    for f in [cm.weights_file(), cm.optimizer_state_file()]:
        if os.path.exists(f) and os.path.isfile(f):
            self.fail("File '{}' wasn't removed after pack".format(f))

    result = os.path.join(fsm.get_path(cm, check=False, create_if_non_exists=False), 'last_checkpoint.zip')
    self.assertTrue(os.path.exists(result) and os.path.isfile(result))

    f = open(cm.weights_file(), 'w')
    f.close()
    f = open(cm.optimizer_state_file(), 'w')
    f.close()
    f = open(cm.trainer_file(), 'w')
    f.close()

    try:
        cm.pack()
        result = os.path.join(fsm.get_path(cm, check=False, create_if_non_exists=False), 'last_checkpoint.zip.old')
        self.assertTrue(os.path.exists(result) and os.path.isfile(result))
    except Exception as err:
        self.fail('Fail to pack with existing previous state file: [{}]'.format(err))
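# The behaviour pinned down by ``test_pack``/``test_unpack``/``test_clear_files`` above can be
# summarised with a rough, assumption-based sketch of what a pack step has to do. This is NOT
# the library's implementation, only the contract the tests check: archive the weights,
# optimizer-state and trainer files into 'last_checkpoint.zip' inside the manager's directory,
# keep any previous archive as 'last_checkpoint.zip.old', and remove the loose source files.
import os
import zipfile


def pack_sketch(dir_path: str, files: list) -> str:
    """Hypothetical helper mirroring the packing contract exercised by the tests."""
    archive = os.path.join(dir_path, 'last_checkpoint.zip')
    if os.path.exists(archive):
        # A previously packed checkpoint is kept around with the '.old' suffix
        os.replace(archive, archive + '.old')
    with zipfile.ZipFile(archive, 'w') as zipf:
        for f in files:
            zipf.write(f, arcname=os.path.basename(f))
    for f in files:
        os.remove(f)  # loose files must be gone after packing
    return archive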