Example #1
def configure_checkpoint_saving(trainer, evaluator, model, optimizer, args):
    to_save = {"model": model, "optimizer": optimizer}
    save_handler = DiskSaver(str(args.output_dir),
                             create_dir=False,
                             require_empty=False)

    # Configure epoch checkpoints.
    interval = 1 if args.dev_mode else min(5, args.max_epochs)
    checkpoint = Checkpoint(
        to_save,
        save_handler,
        n_saved=None,
        global_step_transform=lambda *_: trainer.state.epoch)
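    # Note: passing ``evaluator`` as an extra handler argument below makes Ignite
    # call the checkpoint with ``evaluator`` as its engine; this matters mainly for
    # the "best score" handler, whose score_function reads the evaluator's metrics.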
    trainer.add_event_handler(Events.EPOCH_COMPLETED(every=interval),
                              checkpoint, evaluator)

    # Configure "best score" checkpoints.
    metric_name = "accuracy"
    best_checkpoint = Checkpoint(
        to_save,
        save_handler,
        score_name=metric_name,
        score_function=lambda engine: engine.state.metrics[metric_name],
        filename_prefix="best")
    trainer.add_event_handler(Events.EPOCH_COMPLETED, best_checkpoint,
                              evaluator)
Example #2
def test_load_checkpoint_with_different_num_classes(dirname):
    model = DummyPretrainedModel()
    to_save_single_object = {"model": model}

    trainer = Engine(lambda e, b: None)
    trainer.state = State(epoch=0, iteration=0)

    handler = ModelCheckpoint(dirname, _PREFIX, create_dir=False, n_saved=1)
    handler(trainer, to_save_single_object)

    fname = handler.last_checkpoint
    loaded_checkpoint = torch.load(fname)

    to_load_single_object = {"pretrained_features": model.features}

    with pytest.raises(RuntimeError):
        Checkpoint.load_objects(to_load_single_object, loaded_checkpoint)

    with warnings.catch_warnings():
        warnings.simplefilter("ignore", category=UserWarning)
        Checkpoint.load_objects(to_load_single_object,
                                loaded_checkpoint,
                                strict=False,
                                blah="blah")

    loaded_weights = to_load_single_object["pretrained_features"].state_dict()["weight"]

    assert torch.all(model.state_dict()["features.weight"].eq(loaded_weights))
Example #3
def test_checkpoint_wrong_input():

    with pytest.raises(TypeError,
                       match=r"Argument `to_save` should be a dictionary"):
        Checkpoint(12, lambda x: x, "prefix")

    with pytest.raises(TypeError,
                       match=r"Argument `to_save` should be a dictionary"):
        Checkpoint([12], lambda x: x, "prefix")

    with pytest.raises(ValueError, match=r"No objects to checkpoint."):
        Checkpoint({}, lambda x: x, "prefix")

    model = DummyModel()
    to_save = {'model': model}

    with pytest.raises(TypeError,
                       match=r"Argument `save_handler` should be callable"):
        Checkpoint(to_save, 12, "prefix")

    with pytest.raises(
            ValueError,
            match=r"If `score_name` is provided, then `score_function` should be also provided."):
        Checkpoint(to_save, lambda x: x, score_name="acc")
Example #4
def _resume_training(resume_from: Union[str, Path], to_save: Dict[str, Any]):
    if resume_from:
        checkpoint_fp = Path(resume_from)
        assert checkpoint_fp.exists(), f'Checkpoint "{checkpoint_fp}" is not found'
        print(f'Resuming from a checkpoint: {checkpoint_fp}')
        checkpoint = torch.load(checkpoint_fp.as_posix())
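        # The mapping used for saving can be reused as ``to_load``; each of its
        # keys must be present in the loaded checkpoint dict.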
        Checkpoint.load_objects(to_load=to_save, checkpoint=checkpoint)
Example #5
    def _load_model(self):
        model = get_model(self.config)
        model.to(self.device)
        checkpoint = torch.load(self.checkpoint_path, map_location=self.device)
        Checkpoint.load_objects(to_load={"model": model}, checkpoint=checkpoint)
        model.eval()

        return model
Example #6
    def __call__(self, engine):
        checkpoint = torch.load(self.load_path)
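        # If the file stores a bare state_dict instead of a {key: state_dict}
        # mapping, wrap it so ``Checkpoint.load_objects`` can match the key.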
        if len(self.load_dict) == 1:
            key = list(self.load_dict.keys())[0]
            if key not in checkpoint:
                checkpoint = {key: checkpoint}

        Checkpoint.load_objects(to_load=self.load_dict, checkpoint=checkpoint)
        self.logger.info(f"Restored all variables from {self.load_path}")
Example #7
def gen_save_best_models_by_val_score(
    save_handler: Union[Callable, BaseSaveHandler],
    evaluator: Engine,
    models: Union[torch.nn.Module, Dict[str, torch.nn.Module]],
    metric_name: str,
    n_saved: int = 3,
    trainer: Optional[Engine] = None,
    tag: str = "val",
    **kwargs: Any,
) -> Checkpoint:
    """Method adds a handler to ``evaluator`` to save ``n_saved`` of best models based on the metric
    (named by ``metric_name``) provided by ``evaluator`` (i.e. ``evaluator.state.metrics[metric_name]``).
    Models with highest metric value will be retained. The logic of how to store objects is delegated to
    ``save_handler``.

    Args:
        save_handler (callable or :class:`~ignite.handlers.checkpoint.BaseSaveHandler`): Method or callable class
            used to save the engine and other provided objects. The function receives two arguments: the checkpoint
            as a dictionary and the filename. If ``save_handler`` is a callable class, it can inherit from
            :class:`~ignite.handlers.checkpoint.BaseSaveHandler` and optionally implement a ``remove`` method
            to keep a fixed number of saved checkpoints. If the user needs to save the engine's checkpoint to
            disk, ``save_handler`` can be defined with :class:`~ignite.handlers.DiskSaver`.
        evaluator (Engine): evaluation engine used to provide the score
        models (nn.Module or Mapping): model or dictionary with the object to save. Objects should have
            implemented ``state_dict`` and ``load_state_dict`` methods.
        metric_name (str): metric name to use for score evaluation. This metric should be present in
            `evaluator.state.metrics`.
        n_saved (int, optional): number of best models to store
        trainer (Engine, optional): trainer engine to fetch the epoch when saving the best model.
        tag (str, optional): score name prefix: `{tag}_{metric_name}`. By default, tag is "val".
        **kwargs: optional keyword args to be passed to construct :class:`~ignite.handlers.checkpoint.Checkpoint`.

    Returns:
        A :class:`~ignite.handlers.Checkpoint` handler.
    """
    global_step_transform = None
    if trainer is not None:
        global_step_transform = global_step_from_engine(trainer)

    if isinstance(models, nn.Module):
        to_save = {"model": models}  # type: Dict[str, nn.Module]
    else:
        to_save = models

    best_model_handler = Checkpoint(
        to_save,
        save_handler,
        filename_prefix="best",
        n_saved=n_saved,
        global_step_transform=global_step_transform,
        score_name=f"{tag}_{metric_name.lower()}",
        score_function=Checkpoint.get_default_score_fn(metric_name),
        **kwargs,
    )
    evaluator.add_event_handler(Events.COMPLETED, best_model_handler)

    return best_model_handler
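
A minimal usage sketch for the helper above (hedged: the ``trainer``, ``evaluator`` and ``model`` objects, the "accuracy" metric name, and the "/tmp/ckpts" path are illustrative assumptions, not part of the snippet):

best_handler = gen_save_best_models_by_val_score(
    save_handler=DiskSaver("/tmp/ckpts", create_dir=True, require_empty=False),
    evaluator=evaluator,
    models=model,
    metric_name="accuracy",
    n_saved=2,
    trainer=trainer,
)
# The handler attaches itself to ``evaluator`` on Events.COMPLETED, so no extra
# wiring is needed; files are named like "best_model_{epoch}_val_accuracy={score}.pt".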
Example #8
    def resume(self):
        d = Path(self.save_path)
        pattern = "checkpoint_*.pth"
        saves = list(d.glob(pattern))
        if len(saves) == 0:
            raise FileNotFoundError("No checkpoint to load in %s" % self.save_path)
        fp = max(saves, key=lambda f: f.stat().st_mtime)
        checkpoint = torch.load(fp)
        Checkpoint.load_objects(self.to_save(), checkpoint)
        print("Loaded trainer from %s" % fp)
Example #9
def test_checkpoint__setup_checkpoint():
    save_handler = MagicMock()

    to_save = {'model1': DummyModel(), 'model2': DummyModel()}

    checkpointer = Checkpoint(to_save, save_handler=save_handler)
    chkpt = checkpointer._setup_checkpoint()
    assert isinstance(chkpt, dict)
    for k in ['model1', 'model2']:
        assert k in chkpt
        assert chkpt[k] == to_save[k].state_dict()
Example #10
    def get_model(self, model_name, device, prefix='', path=None):
        if path is None:
            path = join(Constants.MODELS_PATH, model_name)
        best_file, best_loss = get_model_filename(path, prefix)
        model = copy.deepcopy(self.my_models[model_name])
        to_load = {'model': model}

        checkpoint = torch.load(join(path, best_file), map_location=device)
        Checkpoint.load_objects(to_load=to_load, checkpoint=checkpoint)

        return model, best_loss, self.get_thresholds(model_name)
Example #11
def test_checkpoint_load_state_dict():
    true_checkpointer = _setup_checkpoint()

    save_handler = MagicMock(spec=BaseSaveHandler)
    model = DummyModel()
    to_save = {"model": model}
    checkpointer = Checkpoint(to_save, save_handler=save_handler, n_saved=None)

    sd = {"saved": [(0, "model_0.pt"), (10, "model_10.pt"), (20, "model_20.pt")]}
    checkpointer.load_state_dict(sd)
    assert checkpointer._saved == true_checkpointer._saved
Example #12
def extract_model(ckp_file,
                  device='cuda' if torch.cuda.is_available() else 'cpu'):
    tokenizer = BertTokenizer.from_pretrained(base_model)
    model = BertClassificationModel(cls=tokenizer.vocab_size,
                                    model_file=base_model)

    to_load = {'BertClassificationModel': model}
    checkpoint = torch.load(ckp_file, map_location=device)

    Checkpoint.load_objects(to_load=to_load, checkpoint=checkpoint)
    model.to(device)

    model.bert.save_pretrained('./extract_bert/')
Example #13
def create_trainer_and_evaluators(
    model: nn.Module,
    optimizer: Optimizer,
    criterion: nn.Module,
    data_loaders: Dict[str, DataLoader],
    metrics: Dict[str, Metric],
    config: ConfigSchema,
    logger: Logger,
) -> Tuple[Engine, Dict[str, Engine]]:
    trainer = get_trainer(model, criterion, optimizer)
    trainer.logger = logger

    evaluators = get_evaluators(model, metrics)
    setup_evaluation(trainer, evaluators, data_loaders, logger)

    lr_scheduler = get_lr_scheduler(config, optimizer, trainer, evaluators["val"])

    to_save = {
        "trainer": trainer,
        "model": model,
        "optimizer": optimizer,
        "lr_scheduler": lr_scheduler,
    }

    common.setup_common_training_handlers(
        trainer=trainer,
        to_save=to_save,
        save_every_iters=config.checkpoint_every,
        save_handler=get_save_handler(config),
        with_pbars=False,
        train_sampler=data_loaders["train"].sampler,
    )
    trainer.add_event_handler(Events.EPOCH_STARTED, lr_scheduler)
    ProgressBar(persist=False).attach(
        trainer,
        metric_names="all",
        event_name=Events.ITERATION_COMPLETED(every=config.log_every_iters),
    )

    resume_from = config.resume_from
    if resume_from is not None:
        checkpoint_fp = Path(resume_from)
        assert checkpoint_fp.exists(), "Checkpoint '{}' is not found".format(
            checkpoint_fp.as_posix()
        )
        logger.info("Resume from a checkpoint: {}".format(checkpoint_fp.as_posix()))
        checkpoint = torch.load(checkpoint_fp.as_posix(), map_location="cpu")
        Checkpoint.load_objects(to_load=to_save, checkpoint=checkpoint)

    return trainer, evaluators
Example #14
def test_checkpoint_wrong_input():

    with pytest.raises(TypeError, match=r"Argument `to_save` should be a dictionary"):
        Checkpoint(12, lambda x: x, "prefix")

    with pytest.raises(TypeError, match=r"Argument `to_save` should be a dictionary"):
        Checkpoint([12], lambda x: x, "prefix")

    with pytest.raises(ValueError, match=r"No objects to checkpoint."):
        Checkpoint({}, lambda x: x, "prefix")

    model = DummyModel()
    to_save = {"model": model}

    with pytest.raises(TypeError, match=r"Argument `save_handler` should be callable"):
        Checkpoint(to_save, 12, "prefix")

    with pytest.raises(
        ValueError, match=r"If `score_name` is provided, then `score_function` should be also provided."
    ):
        Checkpoint(to_save, lambda x: x, score_name="acc")

    with pytest.raises(TypeError, match=r"global_step_transform should be a function."):
        Checkpoint(to_save, lambda x: x, score_function=lambda e: 123, score_name="acc", global_step_transform=123)

    with pytest.warns(UserWarning, match=r"Argument archived is deprecated"):
        Checkpoint(to_save, lambda x: x, score_function=lambda e: 123, score_name="acc", archived=True)

    with pytest.raises(ValueError, match=r"Cannot have key 'checkpointer' if `include_self` is True"):
        Checkpoint({"checkpointer": model}, lambda x: x, include_self=True)
Example #15
def test_setup_filename_pattern():
    # default filename pattern
    assert Checkpoint.setup_filename_pattern() == "{filename_prefix}_{name}_{global_step}_{score_name}={score}.{ext}"

    assert Checkpoint.setup_filename_pattern(False) == "{name}_{global_step}_{score_name}={score}.{ext}"
    assert Checkpoint.setup_filename_pattern(False, False, False) == "{name}_{global_step}.{ext}"
    assert Checkpoint.setup_filename_pattern(False, True, False) == "{name}_{global_step}_{score}.{ext}"
    assert Checkpoint.setup_filename_pattern(False, True, False, False) == "{name}_{score}.{ext}"
    assert Checkpoint.setup_filename_pattern(False, True, True, False) == "{name}_{score_name}={score}.{ext}"

    with pytest.raises(ValueError, match=r"At least one of with_score and with_global_step should be True."):
        Checkpoint.setup_filename_pattern(False, False, False, False)

    with pytest.raises(ValueError, match=r"If with_score_name is True, with_score should be also True"):
        Checkpoint.setup_filename_pattern(True, False, True, True)
Example #16
    def create_callbacks(self):

        ## SETUP CALLBACKS
        print('[INFO] Creating callback functions for training loop...',
              end='')
        # Early Stopping - stops training if the validation loss does not improve for EARLY_STOPPING_PATIENCE epochs
        handler = EarlyStopping(patience=self.config.EARLY_STOPPING_PATIENCE,
                                score_function=score_function_loss,
                                trainer=self.train_engine)
        self.evaluator.add_event_handler(Events.COMPLETED, handler)
        print('Early Stopping ({} epochs)...'.format(
            self.config.EARLY_STOPPING_PATIENCE),
              end='')

        val_checkpointer = Checkpoint(
            {"model": self.model},
            ClearMLSaver(),
            n_saved=1,
            score_function=score_function_acc,
            score_name="val_acc",
            filename_prefix='cub200_{}_ignite_best'.format(
                self.config.MODEL.MODEL_NAME),
            global_step_transform=global_step_from_engine(self.train_engine),
        )
        self.evaluator.add_event_handler(Events.EPOCH_COMPLETED,
                                         val_checkpointer)
        print('Model Checkpointing...', end='')
        print('Done')
Example #17
    def _test(
        to_save,
        filename_prefix="",
        score_function=None,
        score_name=None,
        global_step_transform=None,
        filename_pattern=None,
    ):
        save_handler = MagicMock(spec=BaseSaveHandler)

        checkpointer = Checkpoint(
            to_save,
            save_handler=save_handler,
            filename_prefix=filename_prefix,
            score_function=score_function,
            score_name=score_name,
            global_step_transform=global_step_transform,
            filename_pattern=filename_pattern,
        )

        trainer = Engine(lambda e, b: None)
        trainer.state = State(epoch=12, iteration=203, score=0.9999)

        checkpointer(trainer)
        return checkpointer.last_checkpoint
Example #18
    def _test(to_save, obj, name, score_name=None):
        save_handler = MagicMock(spec=BaseSaveHandler)

        checkpointer = Checkpoint(
            to_save, save_handler=save_handler, score_name=score_name, score_function=lambda e: e.state.epoch
        )

        if score_name is None:
            score_name = ""
        else:
            score_name += "="

        trainer = Engine(lambda e, b: None)
        trainer.state = State(epoch=1, iteration=1)

        checkpointer(trainer)
        assert save_handler.call_count == 1

        metadata = {"basename": name, "score_name": score_name[:-1] if len(score_name) > 0 else None, "priority": 1}
        save_handler.assert_called_with(obj, "{}_{}1.pt".format(name, score_name), metadata)

        trainer.state.epoch = 12
        trainer.state.iteration = 1234

        checkpointer(trainer)
        assert save_handler.call_count == 2
        metadata["priority"] = 12
        save_handler.assert_called_with(obj, "{}_{}12.pt".format(name, score_name), metadata)
        assert save_handler.remove.call_count == 1
        save_handler.remove.assert_called_with("{}_{}1.pt".format(name, score_name))
        assert checkpointer.last_checkpoint == "{}_{}12.pt".format(name, score_name)
Example #19
    def load_trainer_from_checkpoint(self):
        if self.hparams.checkpoint_dir is not None:
            if not self.hparams.load_model_only:
                objects_to_checkpoint = {
                    "trainer": self.trainer,
                    "model": self.model,
                    "optimizer": self.optimizer,
                    "scheduler": self.scheduler,
                }
                if USE_AMP:
                    objects_to_checkpoint["amp"] = amp
            else:
                objects_to_checkpoint = {"model": self.model}
            objects_to_checkpoint = {k: v for k, v in objects_to_checkpoint.items() if v is not None}
            checkpoint = torch.load(self.hparams.checkpoint_dir, map_location="cpu")
            Checkpoint.load_objects(to_load=objects_to_checkpoint, checkpoint=checkpoint)
Example #20
def test_checkpoint_last_checkpoint_on_score():
    save_handler = MagicMock(spec=BaseSaveHandler)
    to_save = {"model": DummyModel()}

    checkpointer = Checkpoint(
        to_save,
        save_handler=save_handler,
        n_saved=None,
        score_name="val_acc",
        score_function=lambda e: e.state.metrics["val_acc"],
    )

    trainer = Engine(lambda e, b: None)

    val_acc = 0.0
    for i in range(10):
        val_acc = i * 0.1
        trainer.state = State(epoch=1,
                              iteration=i,
                              metrics={"val_acc": val_acc})
        checkpointer(trainer)

    assert save_handler.call_count == 10
    assert checkpointer.last_checkpoint == "model_val_acc=0.9000.pt"
Example #21
    def _test(to_save, obj, name, score_name=None):
        save_handler = MagicMock(spec=BaseSaveHandler)

        checkpointer = Checkpoint(to_save,
                                  save_handler=save_handler,
                                  score_name=score_name,
                                  score_function=lambda e: e.state.epoch)

        if score_name is None:
            score_name = ""
        else:
            score_name += "="

        trainer = Engine(lambda e, b: None)
        trainer.state = State(epoch=1, iteration=1)

        checkpointer(trainer)
        assert save_handler.call_count == 1

        save_handler.assert_called_with(obj,
                                        "{}_{}1.pt".format(name, score_name))

        trainer.state.epoch = 12
        trainer.state.iteration = 1234

        checkpointer(trainer)
        assert save_handler.call_count == 2
        save_handler.assert_called_with(obj,
                                        "{}_{}12.pt".format(name, score_name))
        assert save_handler.remove.call_count == 1
        save_handler.remove.assert_called_with("{}_{}1.pt".format(
            name, score_name))
        assert checkpointer.last_checkpoint == "{}_{}12.pt".format(
            name, score_name)
Example #22
    def _test(to_save, obj, name):
        save_handler = MagicMock(spec=BaseSaveHandler)

        trainer = Engine(lambda e, b: None)
        evaluator = Engine(lambda e, b: None)
        trainer.state = State(epoch=11, iteration=1)

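        # Attached to the evaluator, but filenames use the trainer's epoch via
        # global_step_transform, so saved files are tagged by training progress.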
        checkpointer = Checkpoint(
            to_save,
            save_handler=save_handler,
            global_step_transform=lambda _1, _2: trainer.state.epoch,
            score_function=lambda e: e.state.metrics["val_acc"],
        )

        evaluator.state = State(epoch=1,
                                iteration=1000,
                                metrics={"val_acc": 0.77})
        checkpointer(evaluator)
        assert save_handler.call_count == 1

        save_handler.assert_called_with(obj, "{}_11_0.7700.pt".format(name))

        trainer.state.epoch = 12
        evaluator.state.metrics["val_acc"] = 0.78

        checkpointer(evaluator)
        assert save_handler.call_count == 2
        save_handler.assert_called_with(obj, "{}_12_0.7800.pt".format(name))
        assert save_handler.remove.call_count == 1
        save_handler.remove.assert_called_with("{}_11_0.7700.pt".format(name))
        assert checkpointer.last_checkpoint == "{}_12_0.7800.pt".format(name)
Example #23
    def _test(filename_prefix, to_save, obj, name):
        save_handler = MagicMock(spec=BaseSaveHandler)

        checkpointer = Checkpoint(
            to_save,
            save_handler=save_handler,
            filename_prefix=filename_prefix,
            global_step_transform=lambda e, _: e.state.epoch,
        )

        trainer = Engine(lambda e, b: None)
        trainer.state = State(epoch=1, iteration=1)

        checkpointer(trainer)
        assert save_handler.call_count == 1

        if len(filename_prefix) > 0:
            filename_prefix += "_"

        save_handler.assert_called_with(
            obj, "{}{}_1.pt".format(filename_prefix, name))

        trainer.state.epoch = 12
        trainer.state.iteration = 1234
        checkpointer(trainer)
        assert save_handler.call_count == 2
        save_handler.assert_called_with(
            obj, "{}{}_12.pt".format(filename_prefix, name))
        assert save_handler.remove.call_count == 1
        save_handler.remove.assert_called_with("{}{}_1.pt".format(
            filename_prefix, name))
        assert checkpointer.last_checkpoint == "{}{}_12.pt".format(
            filename_prefix, name)
Example #24
    def _test(to_save, obj, name):
        save_handler = MagicMock(spec=BaseSaveHandler)

        checkpointer = Checkpoint(to_save,
                                  save_handler=save_handler,
                                  score_name="loss",
                                  score_function=lambda e: e.state.score)

        trainer = Engine(lambda e, b: None)
        trainer.state = State(epoch=1, iteration=1, score=-0.77)

        checkpointer(trainer)
        assert save_handler.call_count == 1

        save_handler.assert_called_with(obj, "{}_loss=-0.7700.pt".format(name))

        trainer.state.epoch = 12
        trainer.state.iteration = 1234
        trainer.state.score = -0.76

        checkpointer(trainer)
        assert save_handler.call_count == 2
        save_handler.assert_called_with(obj, "{}_loss=-0.7600.pt".format(name))
        assert save_handler.remove.call_count == 1
        save_handler.remove.assert_called_with(
            "{}_loss=-0.7700.pt".format(name))
        assert checkpointer.last_checkpoint == "{}_loss=-0.7600.pt".format(
            name)
Example #25
    def _test(to_save, obj, name):
        save_handler = MagicMock(spec=BaseSaveHandler)

        checkpointer = Checkpoint(to_save, save_handler=save_handler)
        assert checkpointer.last_checkpoint is None

        trainer = Engine(lambda e, b: None)
        trainer.state = State(epoch=0, iteration=0)

        checkpointer(trainer)
        assert save_handler.call_count == 1

        metadata = {"basename": name, "score_name": None, "priority": 0}
        save_handler.assert_called_with(obj, "{}_0.pt".format(name), metadata)

        trainer.state.epoch = 12
        trainer.state.iteration = 1234
        checkpointer(trainer)
        assert save_handler.call_count == 2
        metadata["priority"] = 1234
        save_handler.assert_called_with(obj, "{}_1234.pt".format(name),
                                        metadata)
        assert save_handler.remove.call_count == 1
        save_handler.remove.assert_called_with("{}_0.pt".format(name))
        assert checkpointer.last_checkpoint == "{}_1234.pt".format(name)
Example #26
def test_clearml_disk_saver_integration_no_logger():
    model = torch.nn.Module()
    to_save_serializable = {"model": model}

    with pytest.warns(
            UserWarning,
            match="ClearMLSaver created a temporary checkpoints directory"):
        clearml.Task.current_task = Mock(return_value=object())
        clearml.binding.frameworks.WeightsFileHandler.create_output_model = MagicMock()
        clearml_saver = ClearMLSaver()
        checkpoint = Checkpoint(to_save=to_save_serializable,
                                save_handler=clearml_saver,
                                n_saved=1)

    trainer = Engine(lambda e, b: None)
    trainer.state = State(epoch=0, iteration=0)
    checkpoint(trainer)
    trainer.state.iteration = 1
    checkpoint(trainer)

    if clearml_saver._atomic:
        assert clearml.binding.frameworks.WeightsFileHandler.create_output_model.call_count == 2
    else:
        saved_files = list(os.listdir(clearml_saver.dirname))
        assert len(saved_files) == 1
        assert saved_files[0] == "model_1.pt"
Example #27
    def _test(to_save, obj, name):
        save_handler = MagicMock()
        save_handler.remove = MagicMock()

        checkpointer = Checkpoint(to_save,
                                  save_handler=save_handler,
                                  score_function=lambda e: e.state.score)

        trainer = Engine(lambda e, b: None)
        trainer.state = State(epoch=1, iteration=1, score=0.77)

        checkpointer(trainer)
        assert save_handler.call_count == 1

        save_handler.assert_called_with(obj, "{}_0.77.pth".format(name))

        trainer.state.epoch = 12
        trainer.state.iteration = 1234
        trainer.state.score = 0.78

        checkpointer(trainer)
        assert save_handler.call_count == 2
        save_handler.assert_called_with(obj, "{}_0.78.pth".format(name))
        assert save_handler.remove.call_count == 1
        save_handler.remove.assert_called_with("{}_0.77.pth".format(name))
        assert checkpointer.last_checkpoint == "{}_0.78.pth".format(name)
Example #28
def _test_checkpoint_load_objects_ddp(device):
    model = DummyModel().to(device)
    device_ids = None if "cpu" in device.type else [device]
    ddp_model = nn.parallel.DistributedDataParallel(model, device_ids=device_ids)
    opt = torch.optim.SGD(ddp_model.parameters(), lr=0.01)

    # single object:
    to_load = {"model": ddp_model}
    checkpoint = ddp_model.module.state_dict()
    Checkpoint.load_objects(to_load, checkpoint)

    # multiple objects:
    to_load = {"model": ddp_model, "opt": opt}
    checkpoint = {"model": ddp_model.module.state_dict(), "opt": opt.state_dict()}
    Checkpoint.load_objects(to_load, checkpoint)
Example #29
def resume_from_checkpoint(to_save, conf, device=None):
    # type: (Dict[str, Any], DictConfig, Device) -> None
    to_load = {k: v for k, v in to_save.items() if v is not None}

    if conf.drop_state:
        # we might want to swap the optimizer or reset its state
        drop_keys = set(conf.drop_state)
        to_load = {k: v for k, v in to_load.items() if k not in drop_keys}

    checkpoint = torch.load(conf.load, map_location=device)
    ema_key = "model_ema"
    if ema_key in to_load and ema_key not in checkpoint:
        checkpoint[ema_key] = checkpoint["model"]
        logging.warning("There are no EMA weights in the checkpoint. "
                        "Using saved model weights as a starting point for the EMA.")

    Checkpoint.load_objects(to_load=to_load, checkpoint=checkpoint)
Example #30
def test_checkpoint_score_function_wrong_output():
    model = DummyModel()
    to_save = {'model': model}

    checkpointer = Checkpoint(to_save, lambda x: x, score_function=lambda e: {"1": 1}, score_name="acc")
    trainer = Engine(lambda e, b: None)
    trainer.state = State(epoch=0, iteration=0)
    with pytest.raises(ValueError, match=r"Output of score_function should be a number"):
        checkpointer(trainer)