def _load_states_from_file_map(
    *, runner: IRunner, load_map: Dict[str, str]
) -> None:
    """
    Restore component states from the checkpoint files named in ``load_map``.

    Arguments:
        runner (IRunner): current runner; its ``logdir`` is used to resolve
            checkpoints and its attributes receive the loaded states
        load_map (Dict[str, str]): mapping of component name to checkpoint.
            Expected keys are ``'model'``, ``'criterion'``, ``'optimizer'``
            and ``'scheduler'``; other keys will be ignored. Each value is
            either a state alias (``'best'``, ``"best_full"``, ``"last"``,
            ``"last_full"``) or a path to a checkpoint file.

            **NOTE:** criterion, optimizer and scheduler states can only be
            restored from a full checkpoint.

    Raises:
        FileNotFoundError: when a file/state referenced in ``load_map``
            does not exist.
    """
    # checkpoint path -> names of the runner parts to restore from it
    files_to_parts = _required_files(runner.logdir, load_map)

    # Validate every file before touching the runner, so a missing file is
    # reported before any state has been loaded.
    first_missing = next(
        (path for path in files_to_parts if not os.path.isfile(path)), None
    )
    if first_missing is not None:
        raise FileNotFoundError(f"No checkpoint found at {first_missing}!")

    # Each file is read once, and all parts it provides are restored together.
    for checkpoint_path, part_names in files_to_parts.items():
        print(f"=> Loading {', '.join(part_names)} from {checkpoint_path}")
        checkpoint = utils.load_checkpoint(checkpoint_path)
        utils.unpack_checkpoint(
            checkpoint,
            **{name: getattr(runner, name) for name in part_names},
        )
        print(f" loaded: {', '.join(part_names)}")
def _load_checkpoint(*, filename, state: State):
    """
    Load a checkpoint file into the components held by ``state``.

    Unless the current ``state.stage_name`` starts with ``"infer"``, the
    ``stage_name``, ``epoch`` and ``global_epoch`` values are copied from the
    checkpoint onto ``state``. The model, criterion, optimizer and scheduler
    states are then unpacked via ``utils.unpack_checkpoint``.

    NOTE(review): in the visible source, a later
    ``_load_checkpoint(*, filename, runner, load_full=True)`` definition
    rebinds this same name, so this version is shadowed once the module is
    fully executed. Confirm whether it is still needed.

    Arguments:
        filename (str): path to checkpoint
        state (State): state object whose components receive the loaded data

    Raises:
        FileNotFoundError: when ``filename`` does not exist. This subclasses
            ``Exception``, so callers that caught the previously raised bare
            ``Exception`` still catch it.
    """
    if not os.path.isfile(filename):
        raise FileNotFoundError(f"No checkpoint found at {filename}")

    print(f"=> loading checkpoint {filename}")
    checkpoint = utils.load_checkpoint(filename)

    # Inference stages keep their own stage/epoch counters.
    if not state.stage_name.startswith("infer"):
        state.stage_name = checkpoint["stage_name"]
        state.epoch = checkpoint["epoch"]
        state.global_epoch = checkpoint["global_epoch"]

    # @TODO: should we also load,
    # checkpoint_data, main_metric, minimize_metric, valid_loader ?
    # epoch_metrics, valid_metrics ?

    utils.unpack_checkpoint(
        checkpoint,
        model=state.model,
        criterion=state.criterion,
        optimizer=state.optimizer,
        scheduler=state.scheduler,
    )

    print(
        f"loaded checkpoint {filename} "
        f"(global epoch {checkpoint['global_epoch']}, "
        f"epoch {checkpoint['epoch']}, "
        f"stage {checkpoint['stage_name']})"
    )
def _load_checkpoint(*, filename, runner: IRunner, load_full: bool = True) -> None:
    """
    Load a checkpoint file into the current runner.

    With ``load_full`` set, the model, criterion, optimizer and scheduler
    states are restored. If the current stage name does not start with
    ``"infer"``, the ``stage_name``, ``epoch`` and ``global_epoch`` counters
    are also copied from the checkpoint. Without ``load_full``, only the
    model state is restored.

    Arguments:
        filename (str): path to checkpoint
        runner (IRunner): current runner
        load_full (bool): if true (default), criterion, optimizer and
            scheduler states are loaded in addition to the model state.
            The file must contain the keys needed for each loaded part:
            model (``'model_state_dict'``) always, and criterion
            (``'criterion_state_dict'``), optimizer
            (``'optimizer_state_dict'``) and scheduler
            (``'scheduler_state_dict'``) only for a full load.

    Raises:
        FileNotFoundError: when the file specified in ``filename`` does not
            exist.
    """
    if not os.path.isfile(filename):
        raise FileNotFoundError(f"No checkpoint found at {filename}!")

    print(f"=> Loading checkpoint {filename}")
    checkpoint = utils.load_checkpoint(filename)

    # Progress counters are restored only for a full load outside inference.
    if not runner.stage_name.startswith("infer") and load_full:
        for counter in ("stage_name", "epoch", "global_epoch"):
            setattr(runner, counter, checkpoint[counter])

    # @TODO: should we also load,
    # checkpoint_data, main_metric, minimize_metric, valid_loader ?
    # epoch_metrics, valid_metrics ?

    if not load_full:
        # Model-only load: other components keep their current state.
        utils.unpack_checkpoint(checkpoint, model=runner.model)
        print(f"loaded model checkpoint {filename}")
        return

    utils.unpack_checkpoint(
        checkpoint,
        model=runner.model,
        criterion=runner.criterion,
        optimizer=runner.optimizer,
        scheduler=runner.scheduler,
    )
    print(
        f"loaded state checkpoint {filename} "
        f"(global epoch {checkpoint['global_epoch']}, "
        f"epoch {checkpoint['epoch']}, "
        f"stage {checkpoint['stage_name']})"
    )