Example 1
    def _test(epoch_length=None):
        max_epochs = 10
        num_iters = 21
        data = torch.randint(0, 1000, size=(num_iters, ))
        if epoch_length is None:
            epoch_length = num_iters

        for resume_epoch in range(1, max_epochs):
            batch_checker = BatchChecker(data,
                                         init_counter=resume_epoch *
                                         epoch_length)

            def update_fn(engine, batch):
                assert batch_checker.check(batch), \
                    "{} | {}: {} vs {}".format(
                        resume_epoch,
                        batch_checker.counter, batch_checker.true_batch, batch)

            engine = Engine(update_fn)
            resume_state_dict = dict(epoch=resume_epoch,
                                     max_epochs=max_epochs,
                                     epoch_length=epoch_length,
                                     seed=0)
            engine.load_state_dict(resume_state_dict)
            engine.run(data)
            assert engine.state.epoch == max_epochs
            assert engine.state.iteration == epoch_length * max_epochs
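
Note: the `BatchChecker` helper used above is not defined in the snippet. A minimal sketch of what such a helper might look like (hypothetical, not the exact test utility):

class BatchChecker:
    # Hypothetical reimplementation of the test helper: it replays `data`
    # starting at `init_counter` and verifies every batch passed to check().
    def __init__(self, data, init_counter=0):
        self.counter = init_counter
        self.data = data
        self.true_batch = None

    def check(self, batch):
        # The expected batch is the element of `data` at the current position.
        self.true_batch = self.data[self.counter % len(self.data)]
        self.counter += 1
        return (self.true_batch == batch).all()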
Example 2
    def _test(epoch_length=None):

        max_epochs = 5
        num_iters = 21
        data = torch.randint(0, 1000, size=(num_iters, ))
        if epoch_length is None:
            epoch_length = num_iters

        for resume_iteration in range(
                1, min(num_iters * max_epochs, epoch_length * max_epochs), 4):
            batch_checker = BatchChecker(data, init_counter=resume_iteration)

            def update_fn(engine, batch):
                assert batch_checker.check(batch), \
                    "{} | {}: {} vs {}".format(
                        resume_iteration,
                        batch_checker.counter, batch_checker.true_batch, batch)

            engine = Engine(update_fn)

            @engine.on(Events.EPOCH_COMPLETED)
            def check_iteration(engine):
                assert engine.state.iteration == batch_checker.counter

            resume_state_dict = {
                "iteration": resume_iteration,
                "max_epochs": max_epochs,
                "epoch_length": epoch_length,
                "seed": 0
            }
            engine.load_state_dict(resume_state_dict)
            engine.run(data)
            assert engine.state.epoch == max_epochs
            assert engine.state.iteration == epoch_length * max_epochs
Example 3
def test_load_state_dict_asserts():
    engine = Engine(lambda e, b: 1)

    with pytest.raises(TypeError,
                       match=r"Argument state_dict should be a dictionary"):
        engine.load_state_dict("123")

    with pytest.raises(ValueError, match=r"is absent in provided state_dict"):
        engine.load_state_dict({})

    with pytest.raises(ValueError,
                       match=r"state_dict should contain only one of"):
        engine.load_state_dict({"max_epochs": 100, "epoch_length": 120})

    with pytest.raises(ValueError,
                       match=r"state_dict should contain only one of"):
        engine.load_state_dict({
            "max_epochs": 100,
            "epoch_length": 120,
            "iteration": 12,
            "epoch": 123
        })

    engine = Engine(lambda e, b: 1)
    engine.state_dict_user_keys.append("alpha")
    with pytest.raises(ValueError, match=r"Required user state attribute"):
        engine.load_state_dict({
            "max_epochs": 100,
            "epoch_length": 120,
            "iteration": 12
        })
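
The final assertion relies on `state_dict_user_keys`: once a key is registered there, `load_state_dict` requires it in the provided dictionary, and `state_dict` serializes the matching attribute of `engine.state`. A minimal round-trip sketch (assumed usage, not taken from the test suite):

engine = Engine(lambda e, b: 1)
engine.state_dict_user_keys.append("alpha")

@engine.on(Events.STARTED)
def init_user_value(engine):
    engine.state.alpha = 0.1  # attribute backing the registered user key

engine.run([0, 1, 2], max_epochs=1)

sd = engine.state_dict()      # contains "alpha" alongside the standard keys
engine.load_state_dict(sd)    # succeeds because "alpha" is present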
Example 4
def test_load_state_dict_integration():
    engine = Engine(lambda e, b: 1)

    state_dict = {"max_epochs": 100, "epoch_length": 120, "epoch": 5}

    engine.load_state_dict(state_dict)
    engine.add_event_handler(Events.ITERATION_COMPLETED, IterationCounter(5 * 120 + 1))
    engine.add_event_handler(Events.EPOCH_COMPLETED, EpochCounter(6))
    data = range(120)
    engine.run(data)
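
`IterationCounter` and `EpochCounter` are not shown; judging by the arguments (`5 * 120 + 1` and `6`), they assert that the corresponding event first fires with the expected state value and then keeps counting. One possible (hypothetical) implementation:

class _EventValueCounter:
    # Hypothetical helper: the tracked state attribute must equal
    # `expected_start` on the first call and grow by one per event.
    def __init__(self, attr, expected_start):
        self.attr = attr
        self.expected = expected_start

    def __call__(self, engine):
        assert getattr(engine.state, self.attr) == self.expected
        self.expected += 1


def IterationCounter(start):
    return _EventValueCounter("iteration", start)


def EpochCounter(start):
    return _EventValueCounter("epoch", start)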
Example 5
def test_load_state_dict_asserts():
    engine = Engine(lambda e, b: 1)

    with pytest.raises(TypeError,
                       match=r"Argument state_dict should be a dictionary"):
        engine.load_state_dict("123")

    with pytest.raises(ValueError, match=r"is absent in provided state_dict"):
        engine.load_state_dict({})

    with pytest.raises(ValueError,
                       match=r"state_dict should contain only one of"):
        engine.load_state_dict({
            "seed": 0,
            "max_epochs": 100,
            "epoch_length": 120
        })

    with pytest.raises(ValueError,
                       match=r"state_dict should contain only one of"):
        engine.load_state_dict({
            "seed": 0,
            "max_epochs": 100,
            "epoch_length": 120,
            "iteration": 12,
            "epoch": 123
        })
Example 6
    def _test(epoch_length=None):
        max_epochs = 3
        batch_size = 4
        num_iters = 17

        def infinite_data_iterator():
            torch.manual_seed(0)
            while True:
                for _ in range(num_iters):
                    data = torch.randint(0,
                                         1000,
                                         size=(batch_size, ),
                                         device=device)
                    yield data

        if epoch_length is None:
            epoch_length = num_iters

        seen_batchs = []

        def update_fn(engine, batch):
            seen_batchs.append(batch)

        engine = Engine(update_fn)
        engine.run(infinite_data_iterator(),
                   max_epochs=max_epochs,
                   seed=12,
                   epoch_length=epoch_length)

        for resume_iteration in range(
                1, min(num_iters * max_epochs, epoch_length * max_epochs), 7):
            batch_checker = BatchChecker(seen_batchs,
                                         init_counter=resume_iteration)

            def update_fn(engine, batch):
                assert batch_checker.check(batch), "{} | {}: {} vs {}".format(
                    resume_iteration, batch_checker.counter,
                    batch_checker.true_batch, batch)

            engine = Engine(update_fn)
            resume_state_dict = dict(iteration=resume_iteration,
                                     max_epochs=max_epochs,
                                     epoch_length=epoch_length,
                                     seed=12)
            engine.load_state_dict(resume_state_dict)
            engine.run(infinite_data_iterator())
            assert engine.state.epoch == max_epochs
            assert engine.state.iteration == epoch_length * max_epochs, "{} | {} vs {}".format(
                resume_iteration, engine.state.iteration,
                epoch_length * max_epochs)
Example 7
    def _test(with_load_state_dict=False):
        engine = Engine(lambda e, b: None)
        engine.state.alpha = 0.0
        engine.state.beta = 1.0

        if with_load_state_dict:
            engine.load_state_dict({"iteration": 3, "max_epochs": 5, "epoch_length": 5})

        @engine.on(Events.STARTED | Events.EPOCH_STARTED | Events.EPOCH_COMPLETED | Events.COMPLETED)
        def check_custom_attr():
            assert hasattr(engine.state, "alpha") and engine.state.alpha == 0.0
            assert hasattr(engine.state, "beta") and engine.state.beta == 1.0

        engine.run([0, 1, 2, 3, 4], max_epochs=5)
Example 8
def test_load_state_dict_integration(counter_factory):
    engine = Engine(lambda e, b: 1)

    state_dict = {
        "seed": 0,
        "max_epochs": 100,
        "epoch_length": 120,
        "epoch": 5
    }

    engine.load_state_dict(state_dict)
    engine.add_event_handler(Events.ITERATION_COMPLETED,
                             counter_factory('iter', 5 * 120 + 1))
    engine.add_event_handler(Events.EPOCH_COMPLETED,
                             counter_factory('epoch', 6))
    data = list(range(120))
    engine.run(data)
Example 9
def test_load_state_dict_with_params_overriding_integration():

    state_dict = {"max_epochs": 100, "epoch_length": 120, "epoch": 5}
    data = range(120)

    # Override max_epochs
    new_max_epochs = 10
    engine = Engine(lambda e, b: 1)
    engine.load_state_dict(state_dict)
    state = engine.run(data, max_epochs=new_max_epochs)
    assert state.max_epochs == new_max_epochs
    assert state.iteration == state_dict["epoch_length"] * new_max_epochs
    assert state.epoch == new_max_epochs

    with pytest.raises(
            ValueError,
            match=r"Argument max_epochs should be larger than the start epoch"
    ):
        engine.load_state_dict(state_dict)
        engine.run(data, max_epochs=3)

    # Override epoch_length
    with pytest.raises(
            ValueError,
            match=r"Argument epoch_length should be same as in the state"):
        engine.load_state_dict(state_dict)
        engine.run(data, epoch_length=90)
Example 10
    def attach(
        self,
        trainer: Engine,
        to_save: Mapping,
        output_transform: Callable = lambda output: output,
        num_iter: Optional[int] = None,
        end_lr: float = 10.0,
        step_mode: str = "exp",
        smooth_f: float = 0.05,
        diverge_th: float = 5.0,
    ):
        """Attaches lr_finder to a given trainer. It also resets model and optimizer at the end of the run.

        Usage:

        .. code-block:: python

            to_save = {"model": model, "optimizer": optimizer}
            with lr_finder.attach(trainer, to_save=to_save) as trainer_with_lr_finder:
                trainer_with_lr_finder.run(dataloader)

        Args:
            trainer (Engine): lr_finder is attached to this trainer. Please keep in mind that all attached handlers
                will be executed.
            to_save (Mapping): dictionary with optimizer and other objects that need to be restored after running
                the LR finder. For example, `to_save={'optimizer': optimizer, 'model': model}`. All objects should
                implement `state_dict` and `load_state_dict` methods.
            output_transform (callable, optional): function that transforms the trainer's `state.output` after each
                iteration. It must return the loss of that iteration.
            num_iter (int, optional): number of iterations for lr schedule between base lr and end_lr. Default, it will
                run for `trainer.state.epoch_length * trainer.state.max_epochs`.
            end_lr (float, optional): upper bound for lr search. Default, 10.0.
            step_mode (str, optional): "exp" or "linear", which way should the lr be increased from optimizer's initial
                lr to `end_lr`. Default, "exp".
            smooth_f (float, optional): loss smoothing factor in range `[0, 1)`. Default, 0.05.
            diverge_th (float, optional): Used for stopping the search when `current loss > diverge_th * best_loss`.
                Default, 5.0.

        Note:
            lr_finder cannot be attached to more than one trainer at a time.

        Returns:
            trainer_with_lr_finder: trainer used for finding the lr
        """
        if not isinstance(to_save, Mapping):
            raise TypeError("Argument to_save should be a mapping, but given {}".format(type(to_save)))

        Checkpoint._check_objects(to_save, "state_dict")
        Checkpoint._check_objects(to_save, "load_state_dict")

        if "optimizer" not in to_save:
            raise ValueError("Mapping to_save should contain 'optimizer' key")

        if not isinstance(to_save["optimizer"], torch.optim.Optimizer):
            raise TypeError(
                "Object to_save['optimizer'] should be torch optimizer, but given {}".format(type(to_save["optimizer"]))
            )

        if smooth_f < 0 or smooth_f >= 1:
            raise ValueError("smooth_f is outside the range [0, 1]")
        if diverge_th < 1:
            raise ValueError("diverge_th should be larger than 1")
        if step_mode not in ["exp", "linear"]:
            raise ValueError("step_mode should be 'exp' or 'linear', but given {}".format(step_mode))
        if num_iter is not None:
            if not isinstance(num_iter, int):
                raise TypeError("if provided, num_iter should be an integer, but give {}".format(num_iter))
            if num_iter <= 0:
                raise ValueError("if provided, num_iter should be positive, but give {}".format(num_iter))

        # store to_save
        with tempfile.TemporaryDirectory() as tmpdirname:
            obj = {k: o.state_dict() for k, o in to_save.items()}
            # add trainer
            obj["trainer"] = trainer.state_dict()
            cache_filepath = Path(tmpdirname) / "ignite_lr_finder_cache.pt"
            torch.save(obj, cache_filepath.as_posix())

            optimizer = to_save["optimizer"]
            # Attach handlers
            if not trainer.has_event_handler(self._run):
                trainer.add_event_handler(
                    Events.STARTED,
                    self._run,
                    optimizer,
                    output_transform,
                    num_iter,
                    end_lr,
                    step_mode,
                    smooth_f,
                    diverge_th,
                )
            if not trainer.has_event_handler(self._warning):
                trainer.add_event_handler(Events.COMPLETED, self._warning)
            if not trainer.has_event_handler(self._reset):
                trainer.add_event_handler(Events.COMPLETED, self._reset)

            yield trainer
            self._detach(trainer)
            # restore to_save and reset trainer's state
            obj = torch.load(cache_filepath.as_posix())
            trainer.load_state_dict(obj["trainer"])
            for k, o in obj.items():
                if k in to_save:
                    to_save[k].load_state_dict(o)
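
Since the body yields the trainer rather than returning it, `attach` is used as a context manager, as the docstring's Usage block shows. A self-contained usage sketch, assuming the enclosing class is ignite's `FastaiLRFinder` (the import path may differ across ignite versions):

import torch
import torch.nn as nn
from ignite.engine import create_supervised_trainer
from ignite.contrib.handlers import FastaiLRFinder  # ignite.handlers in recent releases

model = nn.Linear(10, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-4)
trainer = create_supervised_trainer(model, optimizer, nn.MSELoss())
dataloader = [(torch.rand(4, 10), torch.rand(4, 1)) for _ in range(8)]

lr_finder = FastaiLRFinder()
to_save = {"model": model, "optimizer": optimizer}
with lr_finder.attach(trainer, to_save=to_save) as trainer_with_lr_finder:
    trainer_with_lr_finder.run(dataloader)
# model, optimizer and the trainer state are restored after the with-block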
Example 11
def test_empty_state_dict_load_state_dict():
    engine = Engine(lambda e, b: 1)
    sd = engine.state_dict()
    engine.load_state_dict(sd)
Example 12
    def _test(epoch_length=None):
        max_epochs = 3
        batch_size = 4
        num_iters = 17
        data = torch.randint(0, 1000, size=(num_iters * batch_size, ))

        if epoch_length is None:
            epoch_length = num_iters

        for num_workers in [0, 4]:

            sampler = _setup_sampler(sampler_type, num_iters, batch_size)
            orig_dataloader = torch.utils.data.DataLoader(
                data,
                batch_size=batch_size,
                num_workers=num_workers,
                pin_memory="cuda" in device,
                sampler=sampler,
                drop_last=True,
                shuffle=sampler is None)
            seen_batchs = []

            def update_fn(engine, batch):
                batch_to_device = batch.to(device)
                seen_batchs.append(batch)

            engine = Engine(update_fn)

            if sampler_type == "distributed":

                @engine.on(Events.EPOCH_STARTED)
                def _(engine):
                    sampler.set_epoch(engine.state.epoch)

            engine.run(orig_dataloader,
                       max_epochs=max_epochs,
                       seed=12,
                       epoch_length=epoch_length)

            for resume_iteration in range(
                    1, min(num_iters * max_epochs, epoch_length * max_epochs),
                    7):
                batch_checker = BatchChecker(seen_batchs,
                                             init_counter=resume_iteration)

                sampler = _setup_sampler(sampler_type, num_iters, batch_size)
                resume_dataloader = torch.utils.data.DataLoader(
                    data,
                    batch_size=batch_size,
                    num_workers=num_workers,
                    pin_memory="cuda" in device,
                    sampler=sampler,
                    drop_last=True,
                    shuffle=sampler is None)

                def update_fn(engine, batch):
                    batch_to_device = batch.to(device)
                    assert batch_checker.check(batch), \
                        "{} {} | {}: {} vs {}".format(
                            num_workers, resume_iteration,
                            batch_checker.counter, batch_checker.true_batch, batch)

                engine = Engine(update_fn)

                if sampler_type == "distributed":

                    @engine.on(Events.EPOCH_STARTED)
                    def _(engine):
                        sampler.set_epoch(engine.state.epoch)

                resume_state_dict = dict(iteration=resume_iteration,
                                         max_epochs=max_epochs,
                                         epoch_length=epoch_length,
                                         seed=12)
                engine.load_state_dict(resume_state_dict)
                engine.run(resume_dataloader)
                assert engine.state.epoch == max_epochs
                assert engine.state.iteration == epoch_length * max_epochs, \
                    "{}, {} | {} vs {}".format(num_workers, resume_iteration,
                                               engine.state.iteration,
                                               epoch_length * max_epochs)
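
`_setup_sampler` is also external to the snippet; a rough sketch of a helper with the same signature, assuming `sampler_type` is one of `None`, `"weighted"` or `"distributed"`:

import torch
from torch.utils.data import WeightedRandomSampler
from torch.utils.data.distributed import DistributedSampler


def _setup_sampler(sampler_type, num_iters, batch_size):
    # Hypothetical helper: build a sampler of the requested type, or return
    # None so that the DataLoader falls back to shuffle=True.
    num_samples = num_iters * batch_size
    if sampler_type is None:
        return None
    if sampler_type == "weighted":
        weights = torch.ones(num_samples)
        return WeightedRandomSampler(weights, num_samples=num_samples, replacement=True)
    if sampler_type == "distributed":
        # Single-process fallback; a real distributed run would take
        # num_replicas/rank from torch.distributed.
        return DistributedSampler(range(num_samples), num_replicas=1, rank=0, shuffle=True)
    raise ValueError("Unknown sampler_type: {}".format(sampler_type))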