Example #1
from collections.abc import Mapping

from ignite.engine import Engine


def test_state_dict_integration():
    engine = Engine(lambda e, b: 1)
    data = range(100)
    engine.run(data, max_epochs=10)
    sd = engine.state_dict()
    # all required keys ("epoch_length", "max_epochs") plus one of "iteration" / "epoch"
    assert isinstance(sd, Mapping) and len(sd) == len(engine._state_dict_all_req_keys) + 1
    assert sd["iteration"] == engine.state.iteration == 10 * 100
    assert sd["epoch_length"] == engine.state.epoch_length == 100
    assert sd["max_epochs"] == engine.state.max_epochs == 10
Example #2
from collections.abc import Mapping

from ignite.engine import Engine


def test_state_dict_integration():
    engine = Engine(lambda e, b: 1)
    data = list(range(100))
    # NOTE: `seed` in run() and in the resulting state dict targets an older
    # ignite release; newer versions removed engine-level seeding.
    engine.run(data, max_epochs=10, seed=17)
    sd = engine.state_dict()
    assert isinstance(sd, Mapping) and len(sd) == 4
    assert sd['seed'] == engine.state.seed
    assert sd['iteration'] == engine.state.iteration == 10 * 100
    assert sd['epoch_length'] == engine.state.epoch_length == 100
    assert sd['max_epochs'] == engine.state.max_epochs == 10
Example #3
from collections.abc import Mapping

from ignite.engine import Engine, State


def test_state_dict():
    engine = Engine(lambda e, b: 1)
    sd = engine.state_dict()
    assert isinstance(sd, Mapping) and len(sd) == 0

    def _test(state):
        engine.state = state
        sd = engine.state_dict()
        assert isinstance(sd, Mapping) and len(sd) == len(engine._state_dict_all_req_keys) + 1
        assert sd["iteration"] == engine.state.iteration
        assert sd["epoch_length"] == engine.state.epoch_length
        assert sd["max_epochs"] == engine.state.max_epochs

    _test(State(iteration=500, epoch_length=1000, max_epochs=100))
    _test(State(epoch=5, epoch_length=1000, max_epochs=100))
Example #4
from collections.abc import Mapping

from ignite.engine import Engine, State


def test_state_dict():
    engine = Engine(lambda e, b: 1)
    sd = engine.state_dict()
    assert isinstance(sd, Mapping) and len(sd) == 0

    def _test(state):
        engine.state = state
        sd = engine.state_dict()
        assert isinstance(sd, Mapping) and \
            len(sd) == len(engine._state_dict_all_req_keys) + 1
        assert sd['seed'] == engine.state.seed
        assert sd['iteration'] == engine.state.iteration
        assert sd['epoch_length'] == engine.state.epoch_length
        assert sd['max_epochs'] == engine.state.max_epochs

    # State(seed=...) likewise belongs to the older ignite API.
    _test(State(seed=0, iteration=500, epoch_length=1000, max_epochs=100))
    _test(State(seed=0, epoch=5, epoch_length=1000, max_epochs=100))
Example #5
        # Runs inside a Ray Tune trainable: `model`, `optimizer`, `evaluator`,
        # `val_loader` and `epochs` come from the enclosing scope, and the
        # legacy function-based Tune API (tune.checkpoint_dir / tune.report) is used.
        def post_epoch_actions(trainer_instance: Engine):

            # evaluate model on validation set
            evaluator.run(val_loader)
            state_val_metrics = evaluator.state.metrics

            current_epoch: int = trainer_instance.state.epoch

            with tune.checkpoint_dir(current_epoch) as local_checkpoint_dir:
                # save model, optimizer and trainer checkpoints
                path = os.path.join(local_checkpoint_dir, "checkpoint")
                torch.save(
                    (model.state_dict(), optimizer.state_dict(),
                     trainer_instance.state_dict(), evaluator.state_dict()),
                    path)

            # report validation scores to ray-tune
            report_dict: dict = {
                **state_val_metrics, "done": current_epoch == epochs
            }

            tune.report(**report_dict)
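
On its own, `post_epoch_actions` is never called directly: the enclosing Ray Tune trainable registers it as an epoch-level handler on the trainer. A minimal wiring sketch, assuming `trainer`, `train_loader`, and `epochs` are defined in the same trainable:

from ignite.engine import Events

# Fire post_epoch_actions after every completed training epoch; ignite
# passes the engine itself as the first handler argument.
trainer.add_event_handler(Events.EPOCH_COMPLETED, post_epoch_actions)
trainer.run(train_loader, max_epochs=epochs)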
Example #6
    @contextmanager  # from contextlib: the method yields, so it is used via `with`
    def attach(
        self,
        trainer: Engine,
        to_save: Mapping,
        output_transform: Callable = lambda output: output,
        num_iter: Optional[int] = None,
        end_lr: float = 10.0,
        step_mode: str = "exp",
        smooth_f: float = 0.05,
        diverge_th: float = 5.0,
    ):
        """Attaches lr_finder to a given trainer. It also resets model and optimizer at the end of the run.

        Usage:

        .. code-block:: python

            to_save = {"model": model, "optimizer": optimizer}
            with lr_finder.attach(trainer, to_save=to_save) as trainer_with_lr_finder:
                trainer_with_lr_finder.run(dataloader)

        Args:
            trainer (Engine): lr_finder is attached to this trainer. Please keep in mind that all attached handlers
                will be executed.
            to_save (Mapping): dictionary with the optimizer and other objects that need to be restored after running
                the LR finder. For example, `to_save={'optimizer': optimizer, 'model': model}`. All objects should
                implement `state_dict` and `load_state_dict` methods.
            output_transform (callable, optional): function that transforms the trainer's `state.output` after each
                iteration. It must return the loss of that iteration.
            num_iter (int, optional): number of iterations for the lr schedule between the base lr and `end_lr`. By
                default, it runs for `trainer.state.epoch_length * trainer.state.max_epochs` iterations.
            end_lr (float, optional): upper bound for lr search. Default, 10.0.
            step_mode (str, optional): "exp" or "linear", which way should the lr be increased from optimizer's initial
                lr to `end_lr`. Default, "exp".
            smooth_f (float, optional): loss smoothing factor in range `[0, 1)`. Default, 0.05.
            diverge_th (float, optional): Used for stopping the search when `current loss > diverge_th * best_loss`.
                Default, 5.0.

        Note:
            lr_finder cannot be attached to more than one trainer at a time.

        Returns:
            trainer_with_lr_finder: trainer used for finding the lr
        """
        if not isinstance(to_save, Mapping):
            raise TypeError("Argument to_save should be a mapping, but given {}".format(type(to_save)))

        Checkpoint._check_objects(to_save, "state_dict")
        Checkpoint._check_objects(to_save, "load_state_dict")

        if "optimizer" not in to_save:
            raise ValueError("Mapping to_save should contain 'optimizer' key")

        if not isinstance(to_save["optimizer"], torch.optim.Optimizer):
            raise TypeError(
                "Object to_save['optimizer'] should be torch optimizer, but given {}".format(type(to_save["optimizer"]))
            )

        if smooth_f < 0 or smooth_f >= 1:
            raise ValueError("smooth_f is outside the range [0, 1)")
        if diverge_th < 1:
            raise ValueError("diverge_th should be at least 1")
        if step_mode not in ["exp", "linear"]:
            raise ValueError("step_mode should be 'exp' or 'linear', but given {}".format(step_mode))
        if num_iter is not None:
            if not isinstance(num_iter, int):
                raise TypeError("if provided, num_iter should be an integer, but give {}".format(num_iter))
            if num_iter <= 0:
                raise ValueError("if provided, num_iter should be positive, but give {}".format(num_iter))

        # store to_save
        with tempfile.TemporaryDirectory() as tmpdirname:
            obj = {k: o.state_dict() for k, o in to_save.items()}
            # add trainer
            obj["trainer"] = trainer.state_dict()
            cache_filepath = Path(tmpdirname) / "ignite_lr_finder_cache.pt"
            torch.save(obj, cache_filepath.as_posix())

            optimizer = to_save["optimizer"]
            # Attach handlers
            if not trainer.has_event_handler(self._run):
                trainer.add_event_handler(
                    Events.STARTED,
                    self._run,
                    optimizer,
                    output_transform,
                    num_iter,
                    end_lr,
                    step_mode,
                    smooth_f,
                    diverge_th,
                )
            if not trainer.has_event_handler(self._warning):
                trainer.add_event_handler(Events.COMPLETED, self._warning)
            if not trainer.has_event_handler(self._reset):
                trainer.add_event_handler(Events.COMPLETED, self._reset)

            yield trainer
            self._detach(trainer)
            # restore to_save and reset trainer's state
            obj = torch.load(cache_filepath.as_posix())
            trainer.load_state_dict(obj["trainer"])
            for k, o in obj.items():
                if k in to_save:
                    to_save[k].load_state_dict(o)
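
Since `attach` yields the trainer, it is consumed as a context manager; on exit, the cached state dicts restore the model, optimizer, and trainer. A minimal usage sketch, assuming `model`, `optimizer`, `trainer`, and `dataloader` are already defined (the import path moved from `ignite.contrib.handlers` to `ignite.handlers` in later releases):

from ignite.contrib.handlers import FastaiLRFinder

lr_finder = FastaiLRFinder()
to_save = {"model": model, "optimizer": optimizer}

# Run the finder; model/optimizer/trainer state is restored on exit.
with lr_finder.attach(trainer, to_save=to_save) as trainer_with_lr_finder:
    trainer_with_lr_finder.run(dataloader)

print(lr_finder.lr_suggestion())  # lr at the steepest drop of the loss curve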
Example #7
from ignite.engine import Engine


def test_empty_state_dict_load_state_dict():
    engine = Engine(lambda e, b: 1)
    sd = engine.state_dict()
    engine.load_state_dict(sd)
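
The round trip above is empty because a fresh engine has nothing to serialize. For a non-trivial resume, the usual pattern is to load a saved state dict and call run() again; a short sketch using the same required keys the tests above check ("epoch_length", "max_epochs", plus one of "iteration" / "epoch"):

from ignite.engine import Engine

engine = Engine(lambda e, b: 1)
# Resume mid-run: iteration 500 of a 10-epoch run, 100 iterations per epoch.
engine.load_state_dict({"iteration": 500, "epoch_length": 100, "max_epochs": 10})
engine.run(range(100))  # continues from epoch 6 and stops after epoch 10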