import os
import os.path as osp
import shutil
import time
from pathlib import Path
from typing import Union

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import Optimizer


def save(
    self,
    model: nn.Module,
    optimizer: optim.Optimizer,
    scheduler: optim.lr_scheduler._LRScheduler,
    epoch: int,
    metric: float,
):
    """Save a full training checkpoint and track the best-performing epoch."""
    # A higher metric is treated as better; update the running best.
    if self.best_metric < metric:
        self.best_metric = metric
        self.best_epoch = epoch
        is_best = True
    else:
        is_best = False

    os.makedirs(self.root_dir, exist_ok=True)
    checkpoint_path = osp.join(self.root_dir, f"{epoch:02d}.pth")
    torch.save(
        {
            "model": model.state_dict(),
            "optimizer": optimizer.state_dict(),
            "scheduler": scheduler.state_dict(),
            "epoch": epoch,
            "best_epoch": self.best_epoch,
            "best_metric": self.best_metric,
        },
        checkpoint_path,
    )
    # Keep a stable alias to the best checkpoint so far.
    if is_best:
        shutil.copy(checkpoint_path, osp.join(self.root_dir, "best.pth"))
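# A minimal loading counterpart for the checkpoint format written by `save`
# above. This is a sketch, not part of the original class: `load_best` is a
# hypothetical helper that assumes the caller has already constructed the
# model/optimizer/scheduler and that `best.pth` exists under `root_dir`.
def load_best(
    root_dir: str,
    model: nn.Module,
    optimizer: optim.Optimizer,
    scheduler: optim.lr_scheduler._LRScheduler,
) -> int:
    """Restore state from `best.pth` and return the epoch it was saved at."""
    checkpoint = torch.load(osp.join(root_dir, "best.pth"), map_location="cpu")
    model.load_state_dict(checkpoint["model"])
    optimizer.load_state_dict(checkpoint["optimizer"])
    scheduler.load_state_dict(checkpoint["scheduler"])
    return checkpoint["epoch"]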
def log_checkpoints(
    checkpoint_dir: Path,
    model: Union[nn.Module, nn.DataParallel],
    optimizer: Optimizer,
    scheduler: optim.lr_scheduler._LRScheduler,
    epoch: int,
) -> None:
    """Serialize a PyTorch model in the `checkpoint_dir`.

    Args:
        checkpoint_dir: the directory to store checkpoints
        model: the model to serialize
        optimizer: the optimizer to be saved
        scheduler: the LR scheduler to be saved
        epoch: the epoch number
    """
    checkpoint_file = 'checkpoint_{}.pt'.format(epoch)
    checkpoint_dir.mkdir(exist_ok=True, parents=True)
    file_path = checkpoint_dir / checkpoint_file

    # Unwrap DataParallel so the saved keys are not prefixed with `module.`.
    if isinstance(model, nn.DataParallel):
        model_state_dict = model.module.state_dict()
    else:
        model_state_dict = model.state_dict()

    torch.save(
        {
            'epoch': epoch,
            'model_state_dict': model_state_dict,
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
        },
        file_path,
    )
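# Hypothetical usage sketch for `log_checkpoints`: checkpoint at the end of
# every training epoch. The toy model, optimizer, and scheduler here are
# stand-ins, not part of the original code; the actual training step is
# elided.
def train(num_epochs: int, checkpoint_dir: Path) -> None:
    model = nn.Linear(10, 2)
    optimizer = optim.SGD(model.parameters(), lr=0.1)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10)
    for epoch in range(num_epochs):
        # ... run one epoch of training here ...
        scheduler.step()
        log_checkpoints(checkpoint_dir, model, optimizer, scheduler, epoch)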
def _save(mdl: nn.Module,
          optimizer: optim.Optimizer,
          scheduler: optim.lr_scheduler._LRScheduler,
          global_counter: int,
          t0: float,
          loss: float,
          current_lr: float,
          ckpt_loc: str) -> str:
    """Save a checkpoint to file.

    Args:
        mdl (nn.Module): The model to checkpoint
        optimizer (optim.Optimizer): The optimizer
        scheduler (optim.lr_scheduler._LRScheduler): The learning-rate scheduler
        global_counter (int): The global training step counter
        t0 (float): The time training was started (seconds since the epoch)
        loss (float): The current loss of the model
        current_lr (float): The current learning rate
        ckpt_loc (str): Directory in which to store model checkpoints

    Returns:
        str: A tab-separated log line: step, elapsed minutes, loss, learning rate
    """
    # Save the model, optimizer, and scheduler state as separate files.
    torch.save(mdl.state_dict(), os.path.join(ckpt_loc, 'mdl.ckpt'))
    torch.save(optimizer.state_dict(), os.path.join(ckpt_loc, 'optimizer.ckpt'))
    torch.save(scheduler.state_dict(), os.path.join(ckpt_loc, 'scheduler.ckpt'))

    # Build a tab-separated progress line for the caller to log.
    elapsed_minutes = (time.time() - t0) / 60
    message_str = (f'{global_counter}\t'
                   f'{elapsed_minutes}\t'
                   f'{loss}\t'
                   f'{current_lr}\n')
    return message_str
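# Sketch of the matching restore step for `_save`, assuming the same three
# files exist under `ckpt_loc`. `_load` is a hypothetical helper, not part of
# the original code; `map_location='cpu'` keeps loading device-agnostic.
def _load(mdl: nn.Module,
          optimizer: optim.Optimizer,
          scheduler: optim.lr_scheduler._LRScheduler,
          ckpt_loc: str) -> None:
    mdl.load_state_dict(
        torch.load(os.path.join(ckpt_loc, 'mdl.ckpt'), map_location='cpu'))
    optimizer.load_state_dict(
        torch.load(os.path.join(ckpt_loc, 'optimizer.ckpt'), map_location='cpu'))
    scheduler.load_state_dict(
        torch.load(os.path.join(ckpt_loc, 'scheduler.ckpt'), map_location='cpu'))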