Example #1
    def test_continue_from_checkpoint(self):
        def on_node(n1, n2):
            self.assertTrue(np.array_equal(n1.numpy(), n2.numpy()))

        model = SimpleModel().train()
        dp = DataProcessor(model=model)
        with self.assertRaises(Model.ModelException):
            dp.save_state()
        try:
            fsm = FileStructManager(self.base_dir, is_continue=False)
            dp.set_checkpoints_manager(CheckpointsManager(fsm))
            dp.save_state()
        except Exception:
            self.fail('DataProcessor failed to save state when CheckpointsManager was defined')

        del dp

        model_new = SimpleModel().train()
        dp = DataProcessor(model=model_new)

        with self.assertRaises(Model.ModelException):
            dp.load()
        try:
            fsm = FileStructManager(base_dir=self.base_dir, is_continue=True)
            dp.set_checkpoints_manager(CheckpointsManager(fsm))
            dp.load()
        except Exception:
            self.fail('DataProcessor failed to load state when CheckpointsManager was defined')

        compare_two_models(self, model, model_new)
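
A plausible sketch of the compare_two_models helper used above (hypothetical: the name appears in the test but its body isn't shown; this assumes it walks both state dicts with dict_pair_recursive_bypass, as test_continue_from_checkpoint in Example #4 does):

def compare_two_models(test_case, m1, m2):
    # Hypothetical helper: traverse both state dicts in parallel and assert
    # that every pair of corresponding tensors holds identical values.
    def on_node(n1, n2):
        test_case.assertTrue(np.array_equal(n1.numpy(), n2.numpy()))

    dict_pair_recursive_bypass(m1.state_dict(), m2.state_dict(), on_node)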
Example #2
    def _save_state(self, ckpts_manager: CheckpointsManager, best_ckpts_manager: Optional[CheckpointsManager],
                    cur_best_state: Optional[float], epoch_idx: int) -> Optional[float]:
        """
        Internal method used to save states after an epoch ends

        :param ckpts_manager: ordinary checkpoints manager
        :param best_ckpts_manager: checkpoints manager used to store the best states
        :param cur_best_state: current best state metric value
        :param epoch_idx: index of the epoch that just finished
        :return: new best state metric value, or None if it was not updated
        """
        def save_trainer(ckp_manager):
            with open(ckp_manager.trainer_file(), 'w') as out:
                json.dump({'last_epoch': epoch_idx}, out)

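        # When a best-state rule is set, compare its fresh metric against the
        # current best; lower is better, so improvements are saved through the
        # best-checkpoints manager.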
        if self._best_state_rule is not None:
            new_best_state = self._best_state_rule()
            if cur_best_state is None:
                self._data_processor.save_state()
                save_trainer(ckpts_manager)
                ckpts_manager.pack()
                return new_best_state
            else:
                if new_best_state <= cur_best_state:
                    self._data_processor.set_checkpoints_manager(best_ckpts_manager)
                    self._data_processor.save_state()
                    save_trainer(best_ckpts_manager)
                    best_ckpts_manager.pack()
                    self._data_processor.set_checkpoints_manager(ckpts_manager)
                    return new_best_state

        self._data_processor.save_state()
        save_trainer(ckpts_manager)
        ckpts_manager.pack()
        return None
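
Note that the comparison new_best_state <= cur_best_state treats lower metric values as better (e.g., a validation loss). A minimal sketch of wiring up such a rule; the attribute name comes from the snippet, while the trainer instance and loss source are hypothetical:

# Hypothetical: the rule is a zero-argument callable returning the current
# validation loss, so a value below the stored best triggers a "best" save.
trainer._best_state_rule = lambda: float(current_validation_loss())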
Example #3
    def test_initialisation(self):
        fsm = FileStructManager(base_dir=self.base_dir, is_continue=False)

        try:
            cm = CheckpointsManager(fsm)
        except Exception as err:
            self.fail("Fail init CheckpointsManager; err: ['{}']".format(err))

        with self.assertRaises(FileStructManager.FSMException):
            CheckpointsManager(fsm)

        os.mkdir(os.path.join(fsm.get_path(cm), 'test_dir'))
        with self.assertRaises(FileStructManager.FSMException):
            CheckpointsManager(fsm)
Example #4
    def test_continue_from_checkpoint(self):
        def on_node(n1, n2):
            self.assertTrue(np.array_equal(n1.numpy(), n2.numpy()))

        model = SimpleModel().train()
        loss = SimpleLoss()

        for optim in [
                torch.optim.SGD(model.parameters(), lr=0.1),
                torch.optim.Adam(model.parameters(), lr=0.1)
        ]:
            train_config = TrainConfig([], loss, optim)

            dp_before = TrainDataProcessor(model=model,
                                           train_config=train_config)
            before_state_dict = model.state_dict().copy()
            dp_before.update_lr(0.023)

            with self.assertRaises(Model.ModelException):
                dp_before.save_state()
            try:
                fsm = FileStructManager(base_dir=self.base_dir,
                                        is_continue=False)
                dp_before.set_checkpoints_manager(CheckpointsManager(fsm))
                dp_before.save_state()
            except Exception:
                self.fail(
                    "Exception on saving state when 'CheckpointsManager' specified"
                )

            fsm = FileStructManager(base_dir=self.base_dir, is_continue=True)
            dp_after = TrainDataProcessor(model=model,
                                          train_config=train_config)
            with self.assertRaises(Model.ModelException):
                dp_after.load()
            try:
                cm = CheckpointsManager(fsm)
                dp_after.set_checkpoints_manager(cm)
                dp_after.load()
            except Exception:
                self.fail("Exception on loading state when 'CheckpointsManager' specified")

            after_state_dict = model.state_dict().copy()

            dict_pair_recursive_bypass(before_state_dict, after_state_dict,
                                       on_node)
            self.assertEqual(dp_before.get_lr(), dp_after.get_lr())

            shutil.rmtree(self.base_dir)
Example #5
    def test_clear_files(self):
        fsm = FileStructManager(base_dir=self.base_dir, is_continue=False)
        cm = CheckpointsManager(fsm)

        f = open(cm.weights_file(), 'w')
        f.close()
        f = open(cm.optimizer_state_file(), 'w')
        f.close()

        cm.clear_files()

        for f in [cm.weights_file(), cm.optimizer_state_file()]:
            if os.path.exists(f) and os.path.isfile(f):
                self.fail("File '{}' doesn't remove after pack".format(f))
Example #6
    def __init__(self, train_config: TrainConfig, fsm: FileStructManager, device: torch.device = None):
        self._fsm = fsm
        self.monitor_hub = MonitorHub()

        self._checkpoint_manager = CheckpointsManager(self._fsm)

        self.__epoch_num = 100
        self._resume_from = None
        self._on_epoch_end = []
        self._best_state_rule = None

        self._train_config = train_config
        self._data_processor = TrainDataProcessor(
            self._train_config, device).set_checkpoints_manager(self._checkpoint_manager)
        self._lr = LearningRate(self._data_processor.get_lr())

        self._stop_rules = []
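
A minimal usage sketch for this constructor; the class name Trainer and the base_dir value are assumptions, while train() is shown in Example #10:

# Hypothetical wiring: build the file structure manager, then run training.
fsm = FileStructManager(base_dir='experiments/run1', is_continue=False)
trainer = Trainer(train_config, fsm, device=torch.device('cuda'))
trainer.train()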
Example #7
 def __init__(self,
              model: Model,
              fsm: FileStructManager,
              device: torch.device = None):
     self._fsm = fsm
     self.__data_processor = DataProcessor(model, device=device)
     checkpoint_manager = CheckpointsManager(self._fsm)
     self.__data_processor.set_checkpoints_manager(checkpoint_manager)
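     # Unpack the checkpoint archive, restore the model state from it,
     # then pack it back so the stored checkpoint stays intact.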
     checkpoint_manager.unpack()
     self.__data_processor.load()
     checkpoint_manager.pack()
Example #8
    def _resume(self) -> int:
        if self._resume_from == 'last':
            ckpts_manager = self._checkpoint_manager
        elif self._resume_from == 'best':
            ckpts_manager = CheckpointsManager(self._fsm, 'best')
        else:
            raise NotImplementedError("Resume parameter may be only 'last' or 'best', not '{}'".format(self._resume_from))
        ckpts_manager.unpack()
        self._data_processor.load()

        with open(ckpts_manager.trainer_file(), 'r') as file:
            start_epoch_idx = json.load(file)['last_epoch'] + 1

        ckpts_manager.pack()
        return start_epoch_idx
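
For reference, the trainer file read here is the one-key JSON document written by save_trainer in Example #2: a checkpoint saved after epoch 7 contains {"last_epoch": 7}, so _resume returns 8.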
Example #9
 def __init__(self,
              model: Model,
              fsm: FileStructManager,
              from_best_state: bool = False):
     self._fsm = fsm
     self._data_processor = DataProcessor(model)
     checkpoint_manager = CheckpointsManager(
         self._fsm, prefix='best' if from_best_state else None)
     self._data_processor.set_checkpoints_manager(checkpoint_manager)
     checkpoint_manager.unpack()
     self._data_processor.load()
     checkpoint_manager.pack()
Example #10
    def train(self) -> None:
        """
        Run training process
        """
        if len(self.__train_config.stages()) < 1:
            raise self.TrainerException("There are no stages for training")

        best_checkpoints_manager = None
        cur_best_state = None
        if self._best_state_rule is not None:
            best_checkpoints_manager = CheckpointsManager(self._fsm, 'best')

        start_epoch_idx = 1
        if self._resume_from is not None:
            # _resume() already returns the index of the next epoch to run
            start_epoch_idx = self._resume()
        self.monitor_hub.add_monitor(ConsoleMonitor())

        with self.monitor_hub:
            for epoch_idx in range(start_epoch_idx,
                                   self.__epoch_num + start_epoch_idx):
                self.monitor_hub.set_epoch_num(epoch_idx)
                for stage in self.__train_config.stages():
                    stage.run(self._data_processor)

                    if stage.metrics_processor() is not None:
                        self.monitor_hub.update_metrics(
                            stage.metrics_processor().get_metrics())

                new_best_state = self._save_state(self._checkpoint_manager,
                                                  best_checkpoints_manager,
                                                  cur_best_state, epoch_idx)
                if new_best_state is not None:
                    cur_best_state = new_best_state

                self._data_processor.update_lr(self._lr.value())

                for clbk in self._on_epoch_end:
                    clbk()

                self._update_losses()
                self.__iterate_by_stages(lambda s: s.on_epoch_end())
Example #11
    def test_unpack(self):
        fsm = FileStructManager(base_dir=self.base_dir, is_continue=False)
        cm = CheckpointsManager(fsm)

        f = open(cm.weights_file(), 'w')
        f.write('1')
        f.close()
        f = open(cm.optimizer_state_file(), 'w')
        f.write('2')
        f.close()
        f = open(cm.trainer_file(), 'w')
        f.write('3')
        f.close()

        cm.pack()

        try:
            cm.unpack()
        except Exception as err:
            self.fail('Exception on unpacking: [{}]'.format(err))

        for i, f in enumerate(
            [cm.weights_file(),
             cm.optimizer_state_file(),
             cm.trainer_file()]):
            if not (os.path.exists(f) and os.path.isfile(f)):
                self.fail("File '{}' doesn't remove after pack".format(f))
            with open(f, 'r') as file:
                if file.read() != str(i + 1):
                    self.fail("File content corrupted")
Example #12
    def test_pack(self):
        fsm = FileStructManager(base_dir=self.base_dir, is_continue=False)
        cm = CheckpointsManager(fsm)
        with self.assertRaises(CheckpointsManager.SMException):
            cm.pack()

        os.mkdir(cm.weights_file())
        os.mkdir(cm.optimizer_state_file())
        with self.assertRaises(CheckpointsManager.SMException):
            cm.pack()

        shutil.rmtree(cm.weights_file())
        shutil.rmtree(cm.optimizer_state_file())

        f = open(cm.weights_file(), 'w')
        f.close()
        f = open(cm.optimizer_state_file(), 'w')
        f.close()
        f = open(cm.trainer_file(), 'w')
        f.close()

        try:
            cm.pack()
        except Exception as err:
            self.fail('Exception on packing files: [{}]'.format(err))

        for f in [cm.weights_file(), cm.optimizer_state_file()]:
            if os.path.exists(f) and os.path.isfile(f):
                self.fail("File '{}' doesn't remove after pack".format(f))

        result = os.path.join(
            fsm.get_path(cm, check=False, create_if_non_exists=False),
            'last_checkpoint.zip')
        self.assertTrue(os.path.exists(result) and os.path.isfile(result))

        f = open(cm.weights_file(), 'w')
        f.close()
        f = open(cm.optimizer_state_file(), 'w')
        f.close()
        f = open(cm.trainer_file(), 'w')
        f.close()

        try:
            cm.pack()
            result = os.path.join(
                fsm.get_path(cm, check=False, create_if_non_exists=False),
                'last_checkpoint.zip.old')
            self.assertTrue(os.path.exists(result) and os.path.isfile(result))
        except Exception as err:
            self.fail('Failed to pack when a previous state file exists: [{}]'.format(err))
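
As the assertions above show, pack() replaces the individual state files with last_checkpoint.zip in the manager's directory, and a subsequent pack() preserves the previous archive as last_checkpoint.zip.old.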