Пример #1
0
def _load_states_from_file_map(*, runner: IRunner,
                               load_map: Dict[str, str]) -> None:
    """
    Restore runner components from the checkpoints listed in ``load_map``.

    ``load_map`` is resolved (together with ``runner.logdir``) into a
    mapping "checkpoint file -> names of runner attributes to restore".
    Every file is verified to exist before anything is loaded, so a
    missing checkpoint fails the call before any state is touched.

    Arguments:
        runner (IRunner): current runner; the attributes named in the
            resolved mapping are read from it and passed to
            ``utils.unpack_checkpoint``.
        load_map (Dict[str, str]): dict with mappings to load.
            Expected keys - ``'model'``, ``'criterion'``,
            ``'optimizer'``, ``'scheduler'``; other keys will be
            ignored.
            Expected values are states (``'best'``, ``"best_full"``,
            ``"last"``, ``"last_full"``) or a path to a checkpoint.
            **NOTE:** successfully loading criterion, optimizer or
            scheduler states requires a full checkpoint.

    Raises:
        FileNotFoundError: when a file/state specified in ``load_map``
            does not exist.
    """
    parts_by_file = _required_files(runner.logdir, load_map)

    # Validate up front: fail before loading any checkpoint at all.
    for checkpoint_path in parts_by_file:
        if os.path.isfile(checkpoint_path):
            continue
        raise FileNotFoundError(f"No checkpoint found at {checkpoint_path}!")

    # Restore the requested parts file by file.
    for checkpoint_path, part_names in parts_by_file.items():
        joined_names = ", ".join(part_names)
        print(f"=> Loading {joined_names} from {checkpoint_path}")

        loaded = utils.load_checkpoint(checkpoint_path)

        # Each part name doubles as the runner attribute that holds it.
        targets = {}
        for part_name in part_names:
            targets[part_name] = getattr(runner, part_name)
        utils.unpack_checkpoint(loaded, **targets)

        print(f"   loaded: {joined_names}")
Пример #2
0
def _load_checkpoint(*, filename, state: State) -> None:
    """
    Load a checkpoint from a file into ``state``.

    The model, criterion, optimizer and scheduler held by ``state`` are
    always passed to ``utils.unpack_checkpoint``. The training progress
    (stage name, epoch, global epoch) is copied from the checkpoint
    only when the current stage name does not start with ``"infer"``.

    Arguments:
        filename (str): path to checkpoint
        state (State): current state; updated in place

    Raises:
        FileNotFoundError: when ``filename`` is not an existing file.
            (A subclass of ``Exception``, so callers that caught the
            generic exception raised previously keep working.)
    """
    if not os.path.isfile(filename):
        # Specific exception type instead of a bare ``Exception`` so that
        # callers can tell a missing file apart from other failures.
        raise FileNotFoundError(f"No checkpoint found at {filename}")

    print(f"=> loading checkpoint {filename}")
    checkpoint = utils.load_checkpoint(filename)

    if not state.stage_name.startswith("infer"):
        state.stage_name = checkpoint["stage_name"]
        state.epoch = checkpoint["epoch"]
        state.global_epoch = checkpoint["global_epoch"]
        # @TODO: should we also load,
        # checkpoint_data, main_metric, minimize_metric, valid_loader ?
        # epoch_metrics, valid_metrics ?

    utils.unpack_checkpoint(
        checkpoint,
        model=state.model,
        criterion=state.criterion,
        optimizer=state.optimizer,
        scheduler=state.scheduler,
    )

    print(
        f"loaded checkpoint {filename} "
        f"(global epoch {checkpoint['global_epoch']}, "
        f"epoch {checkpoint['epoch']}, "
        f"stage {checkpoint['stage_name']})"
    )
Пример #3
0
def _load_checkpoint(*,
                     filename,
                     runner: IRunner,
                     load_full: bool = True) -> None:
    """
    Restore runner state from a checkpoint file.

    Arguments:
        filename (str): path to checkpoint
        runner (IRunner): current runner; updated in place
        load_full (bool): if true (default), the model, criterion,
            optimizer and scheduler of the runner are all unpacked
            from the checkpoint and, unless the runner's stage name
            starts with ``"infer"``, the stage name, epoch and global
            epoch are restored as well. If false, only the model is
            unpacked.
            A full load expects the file to contain the keys needed
            for the model (``'model_state_dict'``),
            criterion (``'criterion_state_dict'``),
            optimizer (``'optimizer_state_dict'``) and
            scheduler (``'scheduler_state_dict'``).

    Raises:
        FileNotFoundError: when the file specified in ``filename``
            does not exist.
    """
    if not os.path.isfile(filename):
        raise FileNotFoundError(f"No checkpoint found at {filename}!")

    print(f"=> Loading checkpoint {filename}")
    checkpoint = utils.load_checkpoint(filename)

    # Evaluated before branching so the stage name is inspected on both
    # the full and the model-only path.
    is_infer_stage = runner.stage_name.startswith("infer")

    # Model-only load: nothing besides the weights is touched.
    if not load_full:
        utils.unpack_checkpoint(checkpoint, model=runner.model)
        print(f"loaded model checkpoint {filename}")
        return

    # Full load: training progress is restored only outside of inference.
    if not is_infer_stage:
        runner.stage_name = checkpoint["stage_name"]
        runner.epoch = checkpoint["epoch"]
        runner.global_epoch = checkpoint["global_epoch"]
        # @TODO: should we also load,
        # checkpoint_data, main_metric, minimize_metric, valid_loader ?
        # epoch_metrics, valid_metrics ?

    components = {
        "model": runner.model,
        "criterion": runner.criterion,
        "optimizer": runner.optimizer,
        "scheduler": runner.scheduler,
    }
    utils.unpack_checkpoint(checkpoint, **components)

    summary = (
        f"loaded state checkpoint {filename} "
        f"(global epoch {checkpoint['global_epoch']}, "
        f"epoch {checkpoint['epoch']}, "
        f"stage {checkpoint['stage_name']})"
    )
    print(summary)