예제 #1
0
파일: classic.py 프로젝트: wanjinchang/vel
    def checkpoint(self,
                   epoch_info: EpochInfo,
                   model: Model,
                   state_dict: dict = None):
        """ When epoch is done, we persist the training state.

        For epoch ``epoch_info.global_epoch_idx`` this:
          * saves ``model.state_dict()`` to ``checkpoint_filename``;
          * if the checkpoint strategy asks for it, replaces the stored best
            checkpoint with the current model weights;
          * saves a "hidden" state dict (copy of ``state_dict`` plus optimizer
            state, callback state and checkpoint-strategy state) to
            ``checkpoint_hidden_filename``;
          * removes the previous epoch's files when the strategy says so;
          * stores ``epoch_info.result`` in the backend.

        :param epoch_info: information about the epoch that just finished
        :param model: model whose weights are persisted
        :param state_dict: optional extra entries for the hidden state; it is
            copied before anything is added, so the caller's dict is not
            extended
        """
        state_dict = state_dict if state_dict is not None else {}
        epoch_idx = epoch_info.global_epoch_idx

        # NOTE(review): confirm which bound clean() expects here; another
        # version of this method passes ``global_epoch_idx - 1``.
        self.clean(epoch_idx)

        self._make_sure_dir_exists()

        # Checkpoint latest
        torch.save(model.state_dict(), self.checkpoint_filename(epoch_idx))

        # Decide on / record the best checkpoint BEFORE the strategy state is
        # written into the hidden state below, so the saved hidden state
        # reflects this epoch's decision. Otherwise a resume from this
        # checkpoint could restore a best index whose file is removed here.
        # (Presumes checkpoint_strategy.write_state_dict records the best
        # checkpoint bookkeeping — TODO confirm.)
        if self.checkpoint_strategy.should_store_best_checkpoint(
                epoch_idx, epoch_info.result):
            best_checkpoint_idx = self.checkpoint_strategy.current_best_checkpoint_idx

            if best_checkpoint_idx is not None:
                try:
                    os.remove(self.checkpoint_best_filename(best_checkpoint_idx))
                except FileNotFoundError:
                    # A restored (possibly stale) best index may point at a
                    # file that is already gone; the goal is met either way.
                    pass

            torch.save(model.state_dict(), self.checkpoint_best_filename(epoch_idx))

            self.checkpoint_strategy.store_best_checkpoint_idx(epoch_idx)

        hidden_state = state_dict.copy()

        if epoch_info.optimizer is not None:
            hidden_state['optimizer'] = epoch_info.optimizer.state_dict()

        for callback in epoch_info.callbacks:
            callback.write_state_dict(hidden_state)

        self.checkpoint_strategy.write_state_dict(hidden_state)

        torch.save(hidden_state, self.checkpoint_hidden_filename(epoch_idx))

        if epoch_idx > 1 and self.checkpoint_strategy.should_delete_previous_checkpoint(epoch_idx):
            prev_epoch_idx = epoch_idx - 1

            os.remove(self.checkpoint_filename(prev_epoch_idx))
            os.remove(self.checkpoint_hidden_filename(prev_epoch_idx))

        self.backend.store(epoch_info.result)
예제 #2
0
    def checkpoint(self, epoch_info: EpochInfo, model: Model):
        """ When epoch is done, we persist the training state.

        For epoch ``epoch_info.global_epoch_idx`` this:
          * saves ``model.state_dict()`` to ``checkpoint_filename``;
          * if the checkpoint strategy asks for it, replaces the stored best
            checkpoint with the current model weights;
          * saves a "hidden" state dict (``epoch_info.state_dict()`` plus the
            checkpoint-strategy state) to ``checkpoint_hidden_filename``;
          * removes the previous epoch's files when the strategy says so;
          * stores ``epoch_info.result`` in the backend.

        :param epoch_info: information about the epoch that just finished
        :param model: model whose weights are persisted
        """
        epoch_idx = epoch_info.global_epoch_idx

        self.clean(epoch_idx - 1)

        self._make_sure_dir_exists()

        # Checkpoint latest
        torch.save(model.state_dict(), self.checkpoint_filename(epoch_idx))

        # Decide on / record the best checkpoint BEFORE the strategy state is
        # written into the hidden state below, so the saved hidden state
        # reflects this epoch's decision. Otherwise a resume from this
        # checkpoint could restore a best index whose file is removed here.
        # (Presumes checkpoint_strategy.write_state_dict records the best
        # checkpoint bookkeeping — TODO confirm.)
        if self.checkpoint_strategy.should_store_best_checkpoint(
                epoch_idx, epoch_info.result):
            best_checkpoint_idx = self.checkpoint_strategy.current_best_checkpoint_idx

            if best_checkpoint_idx is not None:
                try:
                    os.remove(self.checkpoint_best_filename(best_checkpoint_idx))
                except FileNotFoundError:
                    # A restored (possibly stale) best index may point at a
                    # file that is already gone; the goal is met either way.
                    pass

            torch.save(model.state_dict(), self.checkpoint_best_filename(epoch_idx))

            self.checkpoint_strategy.store_best_checkpoint_idx(epoch_idx)

        hidden_state = epoch_info.state_dict()
        self.checkpoint_strategy.write_state_dict(hidden_state)

        torch.save(hidden_state, self.checkpoint_hidden_filename(epoch_idx))

        if epoch_idx > 1 and self.checkpoint_strategy.should_delete_previous_checkpoint(epoch_idx):
            prev_epoch_idx = epoch_idx - 1

            os.remove(self.checkpoint_filename(prev_epoch_idx))
            os.remove(self.checkpoint_hidden_filename(prev_epoch_idx))

        self.backend.store(epoch_info.result)