Example #1
    def save(
        self,
        model: nn.Module,
        optimizer: optim.Optimizer,
        scheduler: optim.lr_scheduler._LRScheduler,
        epoch: int,
        metric: float,
    ):
        # Track whether this epoch achieves the best metric so far.
        if self.best_metric < metric:
            self.best_metric = metric
            self.best_epoch = epoch
            is_best = True
        else:
            is_best = False

        os.makedirs(self.root_dir, exist_ok=True)
        # Bundle model, optimizer, and scheduler state into a single checkpoint.
        torch.save(
            {
                "model": model.state_dict(),
                "optimizer": optimizer.state_dict(),
                "scheduler": scheduler.state_dict(),
                "epoch": epoch,
                "best_epoch": self.best_epoch,
                "best_metric": self.best_metric,
            },
            osp.join(self.root_dir, f"{epoch:02d}.pth"),
        )

        # Mirror the latest checkpoint to a fixed "best" filename.
        if is_best:
            shutil.copy(
                osp.join(self.root_dir, f"{epoch:02d}.pth"),
                osp.join(self.root_dir, "best.pth"),
            )
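
The method above is a fragment: it relies on an enclosing class that provides `root_dir`, `best_metric`, and `best_epoch`, plus module-level imports. A minimal sketch of such a wrapper and its use; the `CheckpointSaver` name, the initial values, and the training-loop variables are assumptions, not part of the original:

import os
import os.path as osp
import shutil

import torch
from torch import nn, optim

class CheckpointSaver:
    """Hypothetical wrapper supplying the state save() relies on."""

    def __init__(self, root_dir: str):
        self.root_dir = root_dir
        self.best_metric = float("-inf")  # so the first metric always becomes the best
        self.best_epoch = -1

    # ... save() as defined above ...

# Usage inside a training loop (model/optimizer/scheduler defined elsewhere):
# saver = CheckpointSaver("checkpoints")
# saver.save(model, optimizer, scheduler, epoch=epoch, metric=val_accuracy)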
Example #2
def log_checkpoints(
    checkpoint_dir: Path,
    model: Union[nn.Module, nn.DataParallel],
    optimizer: Optimizer,
    scheduler: optim.lr_scheduler._LRScheduler,
    epoch: int,
) -> None:
    """
    Serialize the model, optimizer, and scheduler state in `checkpoint_dir`.

    Args:
        checkpoint_dir: the directory to store checkpoints
        model: the model to serialize
        optimizer: the optimizer to be saved
        scheduler: the LR scheduler to be saved
        epoch: the epoch number
    """
    checkpoint_file = 'checkpoint_{}.pt'.format(epoch)
    checkpoint_dir.mkdir(exist_ok=True, parents=True)
    file_path = checkpoint_dir / checkpoint_file

    # Unwrap DataParallel so the checkpoint loads into a plain nn.Module.
    if isinstance(model, nn.DataParallel):
        model_state_dict = model.module.state_dict()
    else:
        model_state_dict = model.state_dict()

    torch.save(  # type: ignore
        {
            'epoch': epoch,
            'model_state_dict': model_state_dict,
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
        },
        file_path,
    )
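
A minimal usage sketch for the function above, assuming the imports it needs and a toy model; all names below are illustrative, not part of the original:

from pathlib import Path

import torch
from torch import nn, optim

model = nn.Linear(10, 2)
optimizer = optim.SGD(model.parameters(), lr=0.1)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10)

log_checkpoints(Path("checkpoints"), model, optimizer, scheduler, epoch=0)
# -> writes checkpoints/checkpoint_0.pt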
Example #3
def _save(mdl: nn.Module, optimizer: optim.Optimizer,
          scheduler: optim.lr_scheduler._LRScheduler, global_counter: int,
          t0: float, loss: float, current_lr: float, ckpt_loc: str) -> str:
    """Saving checkpoint to file

    Args:
        mdl (nn.Module):
            The model to checkpoint
        optimizer (optim.Optimizer):
            The optimizer
        scheduler (optim.lr_scheduler._LRScheduler):
            The scheduler for learning rate
        global_counter (int):
            The global counter for training
        t0 (float):
            The time training was started
        loss (float):
            The loss of the model
        current_lr (float):
            The current learning rate
        ckpt_loc (str):
            Location to store model checkpoints

    Returns:
        str:
            The message string
    """
    # Save each component's state to its own file under ckpt_loc.
    torch.save(mdl.state_dict(), os.path.join(ckpt_loc, 'mdl.ckpt'))
    torch.save(optimizer.state_dict(), os.path.join(ckpt_loc,
                                                    'optimizer.ckpt'))
    torch.save(scheduler.state_dict(), os.path.join(ckpt_loc,
                                                    'scheduler.ckpt'))

    # Tab-separated log line: step, elapsed minutes, loss, learning rate.
    message_str = (f'{global_counter}\t'
                   f'{float(time.time() - t0) / 60}\t'
                   f'{loss}\t'
                   f'{current_lr}\n')
    return message_str
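
Because _save writes three separate files rather than one bundled dict, restoring takes three matching load calls. A sketch of the inverse operation; the _load name is an assumption, not part of the original:

import os

import torch
from torch import nn, optim

def _load(mdl: nn.Module, optimizer: optim.Optimizer,
          scheduler: optim.lr_scheduler._LRScheduler, ckpt_loc: str) -> None:
    # Restore each component from the file _save created for it.
    mdl.load_state_dict(torch.load(os.path.join(ckpt_loc, 'mdl.ckpt')))
    optimizer.load_state_dict(torch.load(os.path.join(ckpt_loc, 'optimizer.ckpt')))
    scheduler.load_state_dict(torch.load(os.path.join(ckpt_loc, 'scheduler.ckpt')))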