Example #1
0
def _load_checkpoint(*, filename, state: State) -> None:
    """
    Load a checkpoint file into the given training state.

    Arguments:
        filename (str): path to checkpoint
        state (State): training state; its ``model``, ``criterion``,
            ``optimizer`` and ``scheduler`` are passed to
            ``utils.unpack_checkpoint``.

    If ``state.stage_name`` does not start with ``"infer"``, the
    ``stage_name``, ``epoch`` and ``global_epoch`` entries stored in the
    checkpoint are also copied onto ``state``. Those three entries are
    read for the log message in every case, so the checkpoint must
    contain them even for inference stages.

    Raises:
        FileNotFoundError: when ``filename`` does not refer to an
            existing file. This is a subclass of ``Exception``, so
            handlers written against the previous bare ``Exception``
            still catch it.
    """
    # Guard clause: fail fast with a specific, catchable error type.
    if not os.path.isfile(filename):
        raise FileNotFoundError(f"No checkpoint found at {filename}")

    print(f"=> loading checkpoint {filename}")
    checkpoint = utils.load_checkpoint(filename)

    # Stages whose name starts with "infer" keep their current counters;
    # every other stage resumes them from the checkpoint.
    if not state.stage_name.startswith("infer"):
        state.stage_name = checkpoint["stage_name"]
        state.epoch = checkpoint["epoch"]
        state.global_epoch = checkpoint["global_epoch"]
        # @TODO: should we also load,
        # checkpoint_data, main_metric, minimize_metric, valid_loader ?
        # epoch_metrics, valid_metrics ?

    utils.unpack_checkpoint(checkpoint,
                            model=state.model,
                            criterion=state.criterion,
                            optimizer=state.optimizer,
                            scheduler=state.scheduler)

    print(f"loaded checkpoint {filename} "
          f"(global epoch {checkpoint['global_epoch']}, "
          f"epoch {checkpoint['epoch']}, "
          f"stage {checkpoint['stage_name']})")
Example #2
0
def _load_checkpoint(*,
                     filename,
                     state: State,
                     load_full: bool = True) -> None:
    """
    Restore a training ``state`` from a checkpoint file.

    Arguments:
        filename (str): path to the checkpoint file
        state (State): training state to restore into
        load_full (bool): when true (default), restore the model,
            criterion, optimizer and scheduler. Unless the current
            stage name starts with ``"infer"``, it also copies the
            stored ``stage_name``, ``epoch`` and ``global_epoch``
            onto ``state``. When false, restore only the model and
            leave everything else on ``state`` untouched.

    The checkpoint is expected to provide ``'model_state_dict'``. A full
    load also expects ``'criterion_state_dict'``,
    ``'optimizer_state_dict'`` and ``'scheduler_state_dict'``. These
    entries are consumed by ``utils.unpack_checkpoint``. A full load
    also reads ``'stage_name'``, ``'epoch'`` and ``'global_epoch'`` for
    its log message, even during inference stages.

    Raises:
        FileNotFoundError: when ``filename`` does not refer to an
            existing file.
    """
    if not os.path.isfile(filename):
        raise FileNotFoundError(f"No checkpoint found at {filename}!")

    print(f"=> Loading checkpoint {filename}")
    checkpoint = utils.load_checkpoint(filename)
    # Queried before branching on ``load_full`` so ``state.stage_name``
    # is accessed regardless of the load mode.
    is_infer_stage = state.stage_name.startswith("infer")

    if not load_full:
        # Model-only restore.
        utils.unpack_checkpoint(checkpoint, model=state.model)
        print(f"loaded model checkpoint {filename}")
        return

    if not is_infer_stage:
        # Resume the stage/epoch counters stored in the checkpoint.
        # @TODO: should we also load,
        # checkpoint_data, main_metric, minimize_metric, valid_loader ?
        # epoch_metrics, valid_metrics ?
        for attr_name in ("stage_name", "epoch", "global_epoch"):
            setattr(state, attr_name, checkpoint[attr_name])

    utils.unpack_checkpoint(
        checkpoint,
        model=state.model,
        criterion=state.criterion,
        optimizer=state.optimizer,
        scheduler=state.scheduler,
    )
    print(
        f"loaded state checkpoint {filename} "
        f"(global epoch {checkpoint['global_epoch']}, "
        f"epoch {checkpoint['epoch']}, "
        f"stage {checkpoint['stage_name']})"
    )