def configure_checkpoint_saving(trainer, evaluator, model, optimizer, args):
    to_save = {"model": model, "optimizer": optimizer}
    save_handler = DiskSaver(str(args.output_dir), create_dir=False, require_empty=False)

    # Configure epoch checkpoints.
    interval = 1 if args.dev_mode else min(5, args.max_epochs)
    checkpoint = Checkpoint(
        to_save,
        save_handler,
        n_saved=None,
        global_step_transform=lambda *_: trainer.state.epoch,
    )
    trainer.add_event_handler(Events.EPOCH_COMPLETED(every=interval), checkpoint, evaluator)

    # Configure "best score" checkpoints.
    metric_name = "accuracy"
    best_checkpoint = Checkpoint(
        to_save,
        save_handler,
        score_name=metric_name,
        score_function=lambda engine: engine.state.metrics[metric_name],
        filename_prefix="best",
    )
    trainer.add_event_handler(Events.EPOCH_COMPLETED, best_checkpoint, evaluator)
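# A minimal usage sketch for configure_checkpoint_saving above, assuming an
# argparse Namespace carrying the fields the function reads (output_dir,
# dev_mode, max_epochs). The model, optimizer, and engines are illustrative
# stand-ins, not from the source.
def _example_configure_checkpoint_saving():
    import argparse
    import os

    import torch
    import torch.nn as nn
    from ignite.engine import Engine

    # DiskSaver above is created with create_dir=False, so the directory must exist.
    os.makedirs("/tmp/checkpoints", exist_ok=True)
    args = argparse.Namespace(output_dir="/tmp/checkpoints", dev_mode=False, max_epochs=20)

    model = nn.Linear(2, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    trainer = Engine(lambda e, b: None)
    # The evaluator is expected to expose state.metrics["accuracy"] after it runs.
    evaluator = Engine(lambda e, b: None)

    configure_checkpoint_saving(trainer, evaluator, model, optimizer, args)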
def test_load_checkpoint_with_different_num_classes(dirname):
    model = DummyPretrainedModel()
    to_save_single_object = {"model": model}

    trainer = Engine(lambda e, b: None)
    trainer.state = State(epoch=0, iteration=0)

    handler = ModelCheckpoint(dirname, _PREFIX, create_dir=False, n_saved=1)
    handler(trainer, to_save_single_object)

    fname = handler.last_checkpoint
    loaded_checkpoint = torch.load(fname)

    to_load_single_object = {"pretrained_features": model.features}

    with pytest.raises(RuntimeError):
        Checkpoint.load_objects(to_load_single_object, loaded_checkpoint)

    with warnings.catch_warnings():
        warnings.simplefilter("ignore", category=UserWarning)
        Checkpoint.load_objects(to_load_single_object, loaded_checkpoint, strict=False, blah="blah")

    loaded_weights = to_load_single_object["pretrained_features"].state_dict()["weight"]
    assert torch.all(model.state_dict()["features.weight"].eq(loaded_weights))
def test_checkpoint_wrong_input():
    with pytest.raises(TypeError, match=r"Argument `to_save` should be a dictionary"):
        Checkpoint(12, lambda x: x, "prefix")

    with pytest.raises(TypeError, match=r"Argument `to_save` should be a dictionary"):
        Checkpoint([12], lambda x: x, "prefix")

    with pytest.raises(ValueError, match=r"No objects to checkpoint."):
        Checkpoint({}, lambda x: x, "prefix")

    model = DummyModel()
    to_save = {"model": model}

    with pytest.raises(TypeError, match=r"Argument `save_handler` should be callable"):
        Checkpoint(to_save, 12, "prefix")

    with pytest.raises(
        ValueError, match=r"If `score_name` is provided, then `score_function` should be also provided."
    ):
        Checkpoint(to_save, lambda x: x, score_name="acc")
def _resume_training(resume_from: Union[str, Path], to_save: Dict[str, Any]):
    if resume_from:
        checkpoint_fp = Path(resume_from)
        assert checkpoint_fp.exists(), f'Checkpoint "{checkpoint_fp}" is not found'
        print(f"Resuming from a checkpoint: {checkpoint_fp}")
        checkpoint = torch.load(checkpoint_fp.as_posix())
        Checkpoint.load_objects(to_load=to_save, checkpoint=checkpoint)
def _load_model(self):
    model = get_model(self.config)
    model.to(self.device)
    checkpoint = torch.load(self.checkpoint_path, map_location=self.device)
    Checkpoint.load_objects(to_load={"model": model}, checkpoint=checkpoint)
    model.eval()
    return model
def __call__(self, engine):
    checkpoint = torch.load(self.load_path)
    if len(self.load_dict) == 1:
        key = list(self.load_dict)[0]
        # If the file holds a bare state_dict rather than a dict keyed by
        # object name, wrap it so Checkpoint.load_objects can match the key.
        if key not in checkpoint:
            checkpoint = {key: checkpoint}
    Checkpoint.load_objects(to_load=self.load_dict, checkpoint=checkpoint)
    self.logger.info(f"Restored all variables from {self.load_path}")
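# A small illustration of the two checkpoint layouts the handler above accepts
# for a single-entry load_dict; the model is an illustrative stand-in. For a
# single object, Checkpoint.load_objects also handles the bare layout directly
# (as the DDP test further below relies on), which is what the wrapping guards.
def _example_checkpoint_layouts():
    import torch.nn as nn

    model = nn.Linear(2, 2)

    # Layout 1: a dict keyed by object name, as Checkpoint saves it.
    wrapped = {"model": model.state_dict()}
    Checkpoint.load_objects(to_load={"model": model}, checkpoint=wrapped)

    # Layout 2: a bare state_dict, which __call__ above wraps under the key.
    bare = model.state_dict()
    Checkpoint.load_objects(to_load={"model": model}, checkpoint=bare)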
def gen_save_best_models_by_val_score(
    save_handler: Union[Callable, BaseSaveHandler],
    evaluator: Engine,
    models: Union[torch.nn.Module, Dict[str, torch.nn.Module]],
    metric_name: str,
    n_saved: int = 3,
    trainer: Optional[Engine] = None,
    tag: str = "val",
    **kwargs: Any,
) -> Checkpoint:
    """Adds a handler to ``evaluator`` to save the ``n_saved`` best models, based on the metric
    (named by ``metric_name``) provided by ``evaluator`` (i.e. ``evaluator.state.metrics[metric_name]``).
    Models with the highest metric values are retained.

    The logic of how to store objects is delegated to ``save_handler``.

    Args:
        save_handler (callable or :class:`~ignite.handlers.checkpoint.BaseSaveHandler`): Method or callable class to
            use to save engine and other provided objects. Function receives two objects: checkpoint as a dictionary
            and filename. If ``save_handler`` is a callable class, it can inherit from
            :class:`~ignite.handlers.checkpoint.BaseSaveHandler` and optionally implement a ``remove`` method to keep
            a fixed number of saved checkpoints. If the user needs to save the engine's checkpoint on disk,
            ``save_handler`` can be defined with :class:`~ignite.handlers.DiskSaver`.
        evaluator (Engine): evaluation engine used to provide the score.
        models (nn.Module or Mapping): model or dictionary with the objects to save. Objects should implement
            ``state_dict`` and ``load_state_dict`` methods.
        metric_name (str): metric name to use for score evaluation. This metric should be present in
            ``evaluator.state.metrics``.
        n_saved (int, optional): number of best models to store.
        trainer (Engine, optional): trainer engine to fetch the epoch when saving the best model.
        tag (str, optional): score name prefix: `{tag}_{metric_name}`. By default, tag is "val".
        **kwargs: optional keyword args to be passed to construct :class:`~ignite.handlers.checkpoint.Checkpoint`.

    Returns:
        A :class:`~ignite.handlers.Checkpoint` handler.
    """
    global_step_transform = None
    if trainer is not None:
        global_step_transform = global_step_from_engine(trainer)

    if isinstance(models, nn.Module):
        to_save = {"model": models}  # type: Dict[str, nn.Module]
    else:
        to_save = models

    best_model_handler = Checkpoint(
        to_save,
        save_handler,
        filename_prefix="best",
        n_saved=n_saved,
        global_step_transform=global_step_transform,
        score_name=f"{tag}_{metric_name.lower()}",
        score_function=Checkpoint.get_default_score_fn(metric_name),
        **kwargs,
    )
    evaluator.add_event_handler(Events.COMPLETED, best_model_handler)

    return best_model_handler
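# A minimal sketch of calling gen_save_best_models_by_val_score, assuming an
# evaluator that computes an "accuracy" metric when it runs; the DiskSaver
# path and the engine/model names are illustrative, not from the source.
def _example_gen_save_best_models():
    import torch.nn as nn
    from ignite.engine import Engine
    from ignite.handlers import DiskSaver

    model = nn.Linear(2, 2)
    trainer = Engine(lambda e, b: None)
    evaluator = Engine(lambda e, b: None)

    gen_save_best_models_by_val_score(
        save_handler=DiskSaver("/tmp/best_models", require_empty=False),
        evaluator=evaluator,
        models=model,
        metric_name="accuracy",
        n_saved=2,
        trainer=trainer,  # checkpoint filenames carry the trainer epoch
    )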
def resume(self):
    d = Path(self.save_path)
    pattern = "checkpoint_*.pth"
    saves = list(d.glob(pattern))
    if len(saves) == 0:
        raise FileNotFoundError("No checkpoint to load in %s" % self.save_path)
    # Resume from the most recently modified checkpoint.
    fp = max(saves, key=lambda f: f.stat().st_mtime)
    checkpoint = torch.load(fp)
    Checkpoint.load_objects(self.to_save(), checkpoint)
    print("Loaded trainer from %s" % fp)
def test_checkpoint__setup_checkpoint():
    save_handler = MagicMock()
    to_save = {"model1": DummyModel(), "model2": DummyModel()}

    checkpointer = Checkpoint(to_save, save_handler=save_handler)
    chkpt = checkpointer._setup_checkpoint()

    assert isinstance(chkpt, dict)
    for k in ["model1", "model2"]:
        assert k in chkpt
        assert chkpt[k] == to_save[k].state_dict()
def get_model(self, model_name, device, prefix='', path=None):
    if path is None:
        path = join(Constants.MODELS_PATH, model_name)
    best_file, best_loss = get_model_filename(path, prefix)
    model = copy.deepcopy(self.my_models[model_name])
    to_load = {'model': model}
    checkpoint = torch.load(join(path, best_file), map_location=device)
    Checkpoint.load_objects(to_load=to_load, checkpoint=checkpoint)
    return model, best_loss, self.get_thresholds(model_name)
def test_checkpoint_load_state_dict():
    true_checkpointer = _setup_checkpoint()

    save_handler = MagicMock(spec=BaseSaveHandler)
    model = DummyModel()
    to_save = {"model": model}
    checkpointer = Checkpoint(to_save, save_handler=save_handler, n_saved=None)

    sd = {"saved": [(0, "model_0.pt"), (10, "model_10.pt"), (20, "model_20.pt")]}
    checkpointer.load_state_dict(sd)
    assert checkpointer._saved == true_checkpointer._saved
def extract_model(ckp_file, device='cuda' if torch.cuda.is_available() else 'cpu'):
    tokenizer = BertTokenizer.from_pretrained(base_model)
    model = BertClassificationModel(cls=tokenizer.vocab_size, model_file=base_model)
    to_load = {'BertClassificationModel': model}
    checkpoint = torch.load(ckp_file, map_location=device)
    Checkpoint.load_objects(to_load=to_load, checkpoint=checkpoint)
    model.to(device)
    model.bert.save_pretrained('./extract_bert/')
def create_trainer_and_evaluators(
    model: nn.Module,
    optimizer: Optimizer,
    criterion: nn.Module,
    data_loaders: Dict[str, DataLoader],
    metrics: Dict[str, Metric],
    config: ConfigSchema,
    logger: Logger,
) -> Tuple[Engine, Dict[str, Engine]]:
    trainer = get_trainer(model, criterion, optimizer)
    trainer.logger = logger

    evaluators = get_evaluators(model, metrics)
    setup_evaluation(trainer, evaluators, data_loaders, logger)

    lr_scheduler = get_lr_scheduler(config, optimizer, trainer, evaluators["val"])

    to_save = {
        "trainer": trainer,
        "model": model,
        "optimizer": optimizer,
        "lr_scheduler": lr_scheduler,
    }
    common.setup_common_training_handlers(
        trainer=trainer,
        to_save=to_save,
        save_every_iters=config.checkpoint_every,
        save_handler=get_save_handler(config),
        with_pbars=False,
        train_sampler=data_loaders["train"].sampler,
    )
    trainer.add_event_handler(Events.EPOCH_STARTED, lr_scheduler)

    ProgressBar(persist=False).attach(
        trainer,
        metric_names="all",
        event_name=Events.ITERATION_COMPLETED(every=config.log_every_iters),
    )

    resume_from = config.resume_from
    if resume_from is not None:
        checkpoint_fp = Path(resume_from)
        assert checkpoint_fp.exists(), "Checkpoint '{}' is not found".format(checkpoint_fp.as_posix())
        logger.info("Resume from a checkpoint: {}".format(checkpoint_fp.as_posix()))
        checkpoint = torch.load(checkpoint_fp.as_posix(), map_location="cpu")
        Checkpoint.load_objects(to_load=to_save, checkpoint=checkpoint)

    return trainer, evaluators
def test_checkpoint_wrong_input():
    with pytest.raises(TypeError, match=r"Argument `to_save` should be a dictionary"):
        Checkpoint(12, lambda x: x, "prefix")

    with pytest.raises(TypeError, match=r"Argument `to_save` should be a dictionary"):
        Checkpoint([12], lambda x: x, "prefix")

    with pytest.raises(ValueError, match=r"No objects to checkpoint."):
        Checkpoint({}, lambda x: x, "prefix")

    model = DummyModel()
    to_save = {"model": model}

    with pytest.raises(TypeError, match=r"Argument `save_handler` should be callable"):
        Checkpoint(to_save, 12, "prefix")

    with pytest.raises(
        ValueError, match=r"If `score_name` is provided, then `score_function` should be also provided."
    ):
        Checkpoint(to_save, lambda x: x, score_name="acc")

    with pytest.raises(TypeError, match=r"global_step_transform should be a function."):
        Checkpoint(
            to_save, lambda x: x, score_function=lambda e: 123, score_name="acc", global_step_transform=123
        )

    with pytest.warns(UserWarning, match=r"Argument archived is deprecated"):
        Checkpoint(to_save, lambda x: x, score_function=lambda e: 123, score_name="acc", archived=True)

    with pytest.raises(ValueError, match=r"Cannot have key 'checkpointer' if `include_self` is True"):
        Checkpoint({"checkpointer": model}, lambda x: x, include_self=True)
def test_setup_filename_pattern():
    # default filename pattern
    assert Checkpoint.setup_filename_pattern() == "{filename_prefix}_{name}_{global_step}_{score_name}={score}.{ext}"

    assert Checkpoint.setup_filename_pattern(False) == "{name}_{global_step}_{score_name}={score}.{ext}"
    assert Checkpoint.setup_filename_pattern(False, False, False) == "{name}_{global_step}.{ext}"
    assert Checkpoint.setup_filename_pattern(False, True, False) == "{name}_{global_step}_{score}.{ext}"
    assert Checkpoint.setup_filename_pattern(False, True, False, False) == "{name}_{score}.{ext}"
    assert Checkpoint.setup_filename_pattern(False, True, True, False) == "{name}_{score_name}={score}.{ext}"

    with pytest.raises(ValueError, match=r"At least one of with_score and with_global_step should be True."):
        Checkpoint.setup_filename_pattern(False, False, False, False)

    with pytest.raises(ValueError, match=r"If with_score_name is True, with_score should be also True"):
        Checkpoint.setup_filename_pattern(True, False, True, True)
def create_callbacks(self):
    ## SETUP CALLBACKS
    print('[INFO] Creating callback functions for training loop...', end='')

    # Early Stopping - stops training if the validation loss does not decrease
    # after EARLY_STOPPING_PATIENCE epochs
    handler = EarlyStopping(
        patience=self.config.EARLY_STOPPING_PATIENCE,
        score_function=score_function_loss,
        trainer=self.train_engine,
    )
    self.evaluator.add_event_handler(Events.COMPLETED, handler)
    print('Early Stopping ({} epochs)...'.format(self.config.EARLY_STOPPING_PATIENCE), end='')

    # Model Checkpointing - keeps the single best model by validation accuracy
    val_checkpointer = Checkpoint(
        {"model": self.model},
        ClearMLSaver(),
        n_saved=1,
        score_function=score_function_acc,
        score_name="val_acc",
        filename_prefix='cub200_{}_ignite_best'.format(self.config.MODEL.MODEL_NAME),
        global_step_transform=global_step_from_engine(self.train_engine),
    )
    self.evaluator.add_event_handler(Events.EPOCH_COMPLETED, val_checkpointer)
    print('Model Checkpointing...', end='')
    print('Done')
def _test(
    to_save,
    filename_prefix="",
    score_function=None,
    score_name=None,
    global_step_transform=None,
    filename_pattern=None,
):
    save_handler = MagicMock(spec=BaseSaveHandler)

    checkpointer = Checkpoint(
        to_save,
        save_handler=save_handler,
        filename_prefix=filename_prefix,
        score_function=score_function,
        score_name=score_name,
        global_step_transform=global_step_transform,
        filename_pattern=filename_pattern,
    )

    trainer = Engine(lambda e, b: None)
    trainer.state = State(epoch=12, iteration=203, score=0.9999)

    checkpointer(trainer)
    return checkpointer.last_checkpoint
def _test(to_save, obj, name, score_name=None):
    save_handler = MagicMock(spec=BaseSaveHandler)

    checkpointer = Checkpoint(
        to_save, save_handler=save_handler, score_name=score_name, score_function=lambda e: e.state.epoch
    )

    if score_name is None:
        score_name = ""
    else:
        score_name += "="

    trainer = Engine(lambda e, b: None)
    trainer.state = State(epoch=1, iteration=1)

    checkpointer(trainer)
    assert save_handler.call_count == 1
    metadata = {"basename": name, "score_name": score_name[:-1] if len(score_name) > 0 else None, "priority": 1}
    save_handler.assert_called_with(obj, "{}_{}1.pt".format(name, score_name), metadata)

    trainer.state.epoch = 12
    trainer.state.iteration = 1234

    checkpointer(trainer)
    assert save_handler.call_count == 2
    metadata["priority"] = 12
    save_handler.assert_called_with(obj, "{}_{}12.pt".format(name, score_name), metadata)

    assert save_handler.remove.call_count == 1
    save_handler.remove.assert_called_with("{}_{}1.pt".format(name, score_name))
    assert checkpointer.last_checkpoint == "{}_{}12.pt".format(name, score_name)
def load_trainer_from_checkpoint(self):
    if self.hparams.checkpoint_dir is not None:
        if not self.hparams.load_model_only:
            objects_to_checkpoint = {
                "trainer": self.trainer,
                "model": self.model,
                "optimizer": self.optimizer,
                "scheduler": self.scheduler,
            }
            if USE_AMP:
                objects_to_checkpoint["amp"] = amp
        else:
            objects_to_checkpoint = {"model": self.model}
        objects_to_checkpoint = {k: v for k, v in objects_to_checkpoint.items() if v is not None}
        checkpoint = torch.load(self.hparams.checkpoint_dir, map_location="cpu")
        Checkpoint.load_objects(to_load=objects_to_checkpoint, checkpoint=checkpoint)
def test_checkpoint_last_checkpoint_on_score():
    save_handler = MagicMock(spec=BaseSaveHandler)
    to_save = {"model": DummyModel()}

    checkpointer = Checkpoint(
        to_save,
        save_handler=save_handler,
        n_saved=None,
        score_name="val_acc",
        score_function=lambda e: e.state.metrics["val_acc"],
    )

    trainer = Engine(lambda e, b: None)

    val_acc = 0.0
    for i in range(10):
        val_acc = i * 0.1
        trainer.state = State(epoch=1, iteration=i, metrics={"val_acc": val_acc})
        checkpointer(trainer)

    assert save_handler.call_count == 10
    assert checkpointer.last_checkpoint == "model_val_acc=0.9000.pt"
def _test(to_save, obj, name, score_name=None):
    save_handler = MagicMock(spec=BaseSaveHandler)

    checkpointer = Checkpoint(
        to_save, save_handler=save_handler, score_name=score_name, score_function=lambda e: e.state.epoch
    )

    if score_name is None:
        score_name = ""
    else:
        score_name += "="

    trainer = Engine(lambda e, b: None)
    trainer.state = State(epoch=1, iteration=1)

    checkpointer(trainer)
    assert save_handler.call_count == 1
    save_handler.assert_called_with(obj, "{}_{}1.pt".format(name, score_name))

    trainer.state.epoch = 12
    trainer.state.iteration = 1234

    checkpointer(trainer)
    assert save_handler.call_count == 2
    save_handler.assert_called_with(obj, "{}_{}12.pt".format(name, score_name))
    assert save_handler.remove.call_count == 1
    save_handler.remove.assert_called_with("{}_{}1.pt".format(name, score_name))
    assert checkpointer.last_checkpoint == "{}_{}12.pt".format(name, score_name)
def _test(to_save, obj, name):
    save_handler = MagicMock(spec=BaseSaveHandler)

    trainer = Engine(lambda e, b: None)
    evaluator = Engine(lambda e, b: None)
    trainer.state = State(epoch=11, iteration=1)

    checkpointer = Checkpoint(
        to_save,
        save_handler=save_handler,
        global_step_transform=lambda _1, _2: trainer.state.epoch,
        score_function=lambda e: e.state.metrics["val_acc"],
    )

    evaluator.state = State(epoch=1, iteration=1000, metrics={"val_acc": 0.77})

    checkpointer(evaluator)
    assert save_handler.call_count == 1
    save_handler.assert_called_with(obj, "{}_11_0.7700.pt".format(name))

    trainer.state.epoch = 12
    evaluator.state.metrics["val_acc"] = 0.78

    checkpointer(evaluator)
    assert save_handler.call_count == 2
    save_handler.assert_called_with(obj, "{}_12_0.7800.pt".format(name))
    assert save_handler.remove.call_count == 1
    save_handler.remove.assert_called_with("{}_11_0.7700.pt".format(name))
    assert checkpointer.last_checkpoint == "{}_12_0.7800.pt".format(name)
def _test(filename_prefix, to_save, obj, name):
    save_handler = MagicMock(spec=BaseSaveHandler)

    checkpointer = Checkpoint(
        to_save,
        save_handler=save_handler,
        filename_prefix=filename_prefix,
        global_step_transform=lambda e, _: e.state.epoch,
    )

    trainer = Engine(lambda e, b: None)
    trainer.state = State(epoch=1, iteration=1)

    checkpointer(trainer)
    assert save_handler.call_count == 1

    if len(filename_prefix) > 0:
        filename_prefix += "_"

    save_handler.assert_called_with(obj, "{}{}_1.pt".format(filename_prefix, name))

    trainer.state.epoch = 12
    trainer.state.iteration = 1234

    checkpointer(trainer)
    assert save_handler.call_count == 2
    save_handler.assert_called_with(obj, "{}{}_12.pt".format(filename_prefix, name))
    assert save_handler.remove.call_count == 1
    save_handler.remove.assert_called_with("{}{}_1.pt".format(filename_prefix, name))
    assert checkpointer.last_checkpoint == "{}{}_12.pt".format(filename_prefix, name)
def _test(to_save, obj, name):
    save_handler = MagicMock(spec=BaseSaveHandler)

    checkpointer = Checkpoint(
        to_save, save_handler=save_handler, score_name="loss", score_function=lambda e: e.state.score
    )

    trainer = Engine(lambda e, b: None)
    trainer.state = State(epoch=1, iteration=1, score=-0.77)

    checkpointer(trainer)
    assert save_handler.call_count == 1
    save_handler.assert_called_with(obj, "{}_loss=-0.7700.pt".format(name))

    trainer.state.epoch = 12
    trainer.state.iteration = 1234
    trainer.state.score = -0.76

    checkpointer(trainer)
    assert save_handler.call_count == 2
    save_handler.assert_called_with(obj, "{}_loss=-0.7600.pt".format(name))
    assert save_handler.remove.call_count == 1
    save_handler.remove.assert_called_with("{}_loss=-0.7700.pt".format(name))
    assert checkpointer.last_checkpoint == "{}_loss=-0.7600.pt".format(name)
def _test(to_save, obj, name):
    save_handler = MagicMock(spec=BaseSaveHandler)

    checkpointer = Checkpoint(to_save, save_handler=save_handler)
    assert checkpointer.last_checkpoint is None

    trainer = Engine(lambda e, b: None)
    trainer.state = State(epoch=0, iteration=0)

    checkpointer(trainer)
    assert save_handler.call_count == 1
    metadata = {"basename": name, "score_name": None, "priority": 0}
    save_handler.assert_called_with(obj, "{}_0.pt".format(name), metadata)

    trainer.state.epoch = 12
    trainer.state.iteration = 1234

    checkpointer(trainer)
    assert save_handler.call_count == 2
    metadata["priority"] = 1234
    save_handler.assert_called_with(obj, "{}_1234.pt".format(name), metadata)
    assert save_handler.remove.call_count == 1
    save_handler.remove.assert_called_with("{}_0.pt".format(name))
    assert checkpointer.last_checkpoint == "{}_1234.pt".format(name)
def test_clearml_disk_saver_integration_no_logger():
    model = torch.nn.Module()
    to_save_serializable = {"model": model}

    with pytest.warns(UserWarning, match="ClearMLSaver created a temporary checkpoints directory"):
        clearml.Task.current_task = Mock(return_value=object())
        clearml.binding.frameworks.WeightsFileHandler.create_output_model = MagicMock()
        clearml_saver = ClearMLSaver()
        checkpoint = Checkpoint(to_save=to_save_serializable, save_handler=clearml_saver, n_saved=1)

    trainer = Engine(lambda e, b: None)
    trainer.state = State(epoch=0, iteration=0)
    checkpoint(trainer)
    trainer.state.iteration = 1
    checkpoint(trainer)

    if clearml_saver._atomic:
        assert clearml.binding.frameworks.WeightsFileHandler.create_output_model.call_count == 2
    else:
        saved_files = list(os.listdir(clearml_saver.dirname))
        assert len(saved_files) == 1
        assert saved_files[0] == "model_1.pt"
def _test(to_save, obj, name):
    save_handler = MagicMock()
    save_handler.remove = MagicMock()

    checkpointer = Checkpoint(to_save, save_handler=save_handler, score_function=lambda e: e.state.score)

    trainer = Engine(lambda e, b: None)
    trainer.state = State(epoch=1, iteration=1, score=0.77)

    checkpointer(trainer)
    assert save_handler.call_count == 1
    save_handler.assert_called_with(obj, "{}_0.77.pth".format(name))

    trainer.state.epoch = 12
    trainer.state.iteration = 1234
    trainer.state.score = 0.78

    checkpointer(trainer)
    assert save_handler.call_count == 2
    save_handler.assert_called_with(obj, "{}_0.78.pth".format(name))
    assert save_handler.remove.call_count == 1
    save_handler.remove.assert_called_with("{}_0.77.pth".format(name))
    assert checkpointer.last_checkpoint == "{}_0.78.pth".format(name)
def _test_checkpoint_load_objects_ddp(device):
    model = DummyModel().to(device)
    device_ids = None if "cpu" in device.type else [device]
    ddp_model = nn.parallel.DistributedDataParallel(model, device_ids=device_ids)
    opt = torch.optim.SGD(ddp_model.parameters(), lr=0.01)

    # single object:
    to_load = {"model": ddp_model}
    checkpoint = ddp_model.module.state_dict()
    Checkpoint.load_objects(to_load, checkpoint)

    # multiple objects:
    to_load = {"model": ddp_model, "opt": opt}
    checkpoint = {"model": ddp_model.module.state_dict(), "opt": opt.state_dict()}
    Checkpoint.load_objects(to_load, checkpoint)
def resume_from_checkpoint(to_save, conf, device=None):
    # type: (Dict[str, Any], DictConfig, Device) -> None
    to_load = {k: v for k, v in to_save.items() if v is not None}
    if conf.drop_state:
        # we might want to swap the optimizer or to reset its state
        drop_keys = set(conf.drop_state)
        to_load = {k: v for k, v in to_load.items() if k not in drop_keys}
    checkpoint = torch.load(conf.load, map_location=device)
    ema_key = "model_ema"
    if ema_key in to_load and ema_key not in checkpoint:
        checkpoint[ema_key] = checkpoint["model"]
        logging.warning(
            "There are no EMA weights in the checkpoint. "
            "Using saved model weights as a starting point for the EMA."
        )
    Checkpoint.load_objects(to_load=to_load, checkpoint=checkpoint)
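# A hedged usage sketch for resume_from_checkpoint above. The type comment
# names DictConfig, so an OmegaConf config is assumed here; the checkpoint
# path, the drop_state value, and the model/optimizer are all illustrative.
def _example_resume_from_checkpoint():
    import torch
    import torch.nn as nn
    from omegaconf import OmegaConf

    model = nn.Linear(2, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

    # Drop the optimizer state, e.g. when swapping optimizers between runs.
    conf = OmegaConf.create({"load": "/tmp/checkpoint_100.pth", "drop_state": ["optimizer"]})

    # "model_ema" is None here, so it is filtered out before loading.
    resume_from_checkpoint({"model": model, "optimizer": optimizer, "model_ema": None}, conf)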
def test_checkpoint_score_function_wrong_output():
    model = DummyModel()
    to_save = {"model": model}

    checkpointer = Checkpoint(to_save, lambda x: x, score_function=lambda e: {"1": 1}, score_name="acc")
    trainer = Engine(lambda e, b: None)
    trainer.state = State(epoch=0, iteration=0)
    with pytest.raises(ValueError, match=r"Output of score_function should be a number"):
        checkpointer(trainer)