def _test_neptune_saver_integration(device):
    """Check that NeptuneSaver saves via the logger and prunes old artifacts.

    Runs two checkpoint calls with ``n_saved=1``: rank 0 must see two
    ``log_artifact`` calls and one ``delete_artifacts`` call (the first
    checkpoint is evicted when the second is saved). Non-zero ranks pass
    ``None`` as the logger, mirroring the original behavior.
    """
    net = torch.nn.Module().to(device)
    serializable = {"model": net}

    logger_mock = None
    if idist.get_rank() == 0:
        # Only rank 0 gets a mocked NeptuneLogger to record saver activity.
        logger_mock = MagicMock(spec=NeptuneLogger)
        logger_mock.log_artifact = MagicMock()
        logger_mock.delete_artifacts = MagicMock()

    handler = NeptuneSaver(logger_mock)
    ckpt = Checkpoint(to_save=serializable, save_handler=handler, n_saved=1)

    engine = Engine(lambda e, b: None)
    engine.state = State(epoch=0, iteration=0)

    # First save, then a second at a later iteration to trigger eviction.
    ckpt(engine)
    engine.state.iteration = 1
    ckpt(engine)

    if idist.get_rank() == 0:
        assert logger_mock.log_artifact.call_count == 2
        assert logger_mock.delete_artifacts.call_count == 1
def resume_from(
    to_load: Mapping,
    checkpoint_fp: Union[str, Path],
    logger: Logger,
    strict: bool = True,
    model_dir: Optional[str] = None,
) -> None:
    """Loads state dict from a checkpoint file to resume the training.

    Parameters
    ----------
    to_load
        a dictionary with objects, e.g. {"model": model, "optimizer": optimizer, ...}
    checkpoint_fp
        path to the checkpoint file; may be a local path or an
        http(s) URL to download the checkpoint from
    logger
        to log info about resuming from a checkpoint
    strict
        whether to strictly enforce that the keys in `state_dict` match the
        keys returned by this module's `state_dict()` function. Default: True
    model_dir
        directory in which to save the object

    Raises
    ------
    FileNotFoundError
        if `checkpoint_fp` is a local path that does not exist.
    """
    # Bug fix: the original test only matched "https://", so an "http://"
    # URL was treated as a local path and raised FileNotFoundError.
    if isinstance(checkpoint_fp, str) and checkpoint_fp.startswith(("http://", "https://")):
        checkpoint = torch.hub.load_state_dict_from_url(
            checkpoint_fp, model_dir=model_dir, map_location="cpu", check_hash=True
        )
    else:
        if isinstance(checkpoint_fp, str):
            checkpoint_fp = Path(checkpoint_fp)
        if not checkpoint_fp.exists():
            raise FileNotFoundError(f"Given {str(checkpoint_fp)} does not exist.")
        # map_location="cpu" so checkpoints saved on GPU load on any machine.
        checkpoint = torch.load(checkpoint_fp, map_location="cpu")

    Checkpoint.load_objects(to_load=to_load, checkpoint=checkpoint, strict=strict)
    logger.info("Successfully resumed from a checkpoint: %s", checkpoint_fp)