Example #1
    def _test(epoch_length=None):

        max_epochs = 5
        num_iters = 21
        torch.manual_seed(0)
        data = torch.randint(0, 1000, size=(num_iters, ))
        if epoch_length is None:
            epoch_length = num_iters

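        # Resume from a range of mid-run iterations; the engine must replay
        # exactly the batches an uninterrupted run would have produced.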
        for resume_iteration in range(
                2, min(num_iters * max_epochs, epoch_length * max_epochs), 4):
            batch_checker = BatchChecker(data, init_counter=resume_iteration)

            def update_fn(_, batch):
                assert batch_checker.check(
                    batch
                ), f"{resume_iteration} | {batch_checker.counter}: {batch_checker.true_batch} vs {batch}"

            engine = DeterministicEngine(update_fn)

            @engine.on(Events.EPOCH_COMPLETED)
            def check_iteration(_):
                assert engine.state.iteration == batch_checker.counter

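            # rng_states=None: nothing to restore here; the data tensor is fixed,
            # so replay depends only on the iteration/epoch arithmetic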
            resume_state_dict = dict(iteration=resume_iteration,
                                     max_epochs=max_epochs,
                                     epoch_length=epoch_length,
                                     rng_states=None)
            engine.load_state_dict(resume_state_dict)
            engine.run(data)
            assert engine.state.epoch == max_epochs
            assert engine.state.iteration == epoch_length * max_epochs
Example #2
    def _test(epoch_length=None):
        max_epochs = 10
        num_iters = 21
        torch.manual_seed(0)
        data = torch.randint(0, 1000, size=(num_iters, ))
        if epoch_length is None:
            epoch_length = num_iters

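        # Same check as the iteration-based resume above, but restarting from
        # whole-epoch boundaries.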
        for resume_epoch in range(1, max_epochs):
            batch_checker = BatchChecker(data,
                                         init_counter=resume_epoch * epoch_length)

            def update_fn(_, batch):
                assert batch_checker.check(
                    batch
                ), f"{resume_epoch} | {batch_checker.counter}: {batch_checker.true_batch} vs {batch}"

            engine = DeterministicEngine(update_fn)

            resume_state_dict = dict(epoch=resume_epoch,
                                     max_epochs=max_epochs,
                                     epoch_length=epoch_length,
                                     rng_states=None)
            engine.load_state_dict(resume_state_dict)
            engine.run(data)
            assert engine.state.epoch == max_epochs
            assert engine.state.iteration == epoch_length * max_epochs
Example #3
def test_engine_no_data_asserts():
    trainer = DeterministicEngine(lambda e, b: None)

    with pytest.raises(
        ValueError,
        match=r"Deterministic engine does not support the option of data=None",
    ):
        trainer.run(max_epochs=10, epoch_length=10)
Example #4
def test_dataloader_no_dataset_kind():
    # tests issue: https://github.com/pytorch/ignite/issues/1022

    engine = DeterministicEngine(lambda e, b: None)

    data = torch.randint(0, 1000, size=(100 * 4, ))
    dataloader = DataLoader(data, batch_size=4)
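    # OldDataLoader is a helper defined elsewhere in the test module; it
    # presumably emulates an older-style DataLoader that lacks the
    # _dataset_kind attribute referenced in the issue above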
    dataloader = OldDataLoader(dataloader)

    engine.run(dataloader)
Example #5
    def _test(epoch_length=None):
        max_epochs = 3
        batch_size = 4
        num_iters = 17

        def infinite_data_iterator():
            while True:
                for _ in range(num_iters):
                    data = torch.randint(0, 1000, size=(batch_size,), device=device)
                    yield data

        if epoch_length is None:
            epoch_length = num_iters

        for resume_iteration in range(
                1, min(num_iters * max_epochs, epoch_length * max_epochs), 7):

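            # First pass: run from scratch with a fixed seed, recording every batch.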
            seen_batches = []

            def update_fn(_, batch):
                seen_batches.append(batch)

            engine = DeterministicEngine(update_fn)

            torch.manual_seed(24)
            engine.run(
                infinite_data_iterator(),
                max_epochs=max_epochs,
                epoch_length=epoch_length,
            )

            batch_checker = BatchChecker(seen_batches, init_counter=resume_iteration)

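            # Second pass: resume mid-run under the same seed; the replayed
            # batches must match the recorded ones.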
            def update_fn(_, batch):
                assert batch_checker.check(
                    batch
                ), f"{resume_iteration} | {batch_checker.counter}: {batch_checker.true_batch} vs {batch}"

            engine = DeterministicEngine(update_fn)

            resume_state_dict = dict(iteration=resume_iteration,
                                     max_epochs=max_epochs,
                                     epoch_length=epoch_length,
                                     rng_states=None)
            engine.load_state_dict(resume_state_dict)
            torch.manual_seed(24)
            engine.run(infinite_data_iterator())
            assert engine.state.epoch == max_epochs
            assert (
                engine.state.iteration == epoch_length * max_epochs
            ), f"{resume_iteration} | {engine.state.iteration} vs {epoch_length * max_epochs}"
Example #6
    def _test(epoch_length=None):
        max_epochs = 5
        batch_size = 4
        num_iters = 21

        def infinite_data_iterator():
            while True:
                for _ in range(num_iters):
                    data = torch.randint(0, 1000, size=(batch_size,), device=device)
                    yield data

        if epoch_length is None:
            epoch_length = num_iters

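        # Record a seeded run, then resume it from each epoch boundary and
        # check that the replayed batches match.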
        for resume_epoch in range(1, max_epochs):
            seen_batches = []

            def update_fn(_, batch):
                # a random op here (e.g. torch.rand(1)) would consume global RNG
                # state and make an exact resume impossible
                seen_batches.append(batch)

            engine = DeterministicEngine(update_fn)
            torch.manual_seed(121)
            engine.run(
                infinite_data_iterator(),
                max_epochs=max_epochs,
                epoch_length=epoch_length,
            )

            batch_checker = BatchChecker(seen_batches,
                                         init_counter=resume_epoch * epoch_length)

            def update_fn(_, batch):
                assert batch_checker.check(
                    batch
                ), f"{resume_epoch} | {batch_checker.counter}: {batch_checker.true_batch} vs {batch}"

            engine = DeterministicEngine(update_fn)

            resume_state_dict = dict(epoch=resume_epoch,
                                     max_epochs=max_epochs,
                                     epoch_length=epoch_length,
                                     rng_states=None)
            engine.load_state_dict(resume_state_dict)
            torch.manual_seed(121)
            engine.run(infinite_data_iterator())
            assert engine.state.epoch == max_epochs
            assert engine.state.iteration == epoch_length * max_epochs
Example #7
def test_concepts_snippet_resume():

    # Imports required by the snippet, kept commented out here:
    # import torch
    # from torch.utils.data import DataLoader

    # from ignite.engine import DeterministicEngine
    # from ignite.utils import manual_seed

    seen_batches = []
    manual_seed(seed=15)

    def random_train_data_loader(size):
        data = torch.arange(0, size)
        return DataLoader(data, batch_size=4, shuffle=True)

    def print_train_data(engine, batch):
        i = engine.state.iteration
        e = engine.state.epoch
        print("train", e, i, batch.tolist())
        seen_batches.append(batch)

    trainer = DeterministicEngine(print_train_data)

    print("Original Run")
    manual_seed(56)
    trainer.run(random_train_data_loader(40), max_epochs=2, epoch_length=5)

    original_batches = list(seen_batches)
    seen_batches = []

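    # Restore the state reached at the end of epoch 1 and re-seed: the resumed
    # run should regenerate epoch 2's batches exactly.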
    print("Resumed Run")
    trainer.load_state_dict({
        "epoch": 1,
        "epoch_length": 5,
        "max_epochs": 2,
        "rng_states": None
    })
    manual_seed(56)
    trainer.run(random_train_data_loader(40))

    resumed_batches = list(seen_batches)
    seen_batches = []
    for b1, b2 in zip(original_batches[5:], resumed_batches):
        assert (b1 == b2).all()
Example #8
def test_engine_with_dataloader_no_auto_batching():
    # tests https://github.com/pytorch/ignite/issues/941

    data = torch.rand(64, 4, 10)
    data_loader = DataLoader(
        data, batch_size=None, sampler=BatchSampler(RandomSampler(data), batch_size=8, drop_last=True)
    )

    counter = [0]

    def foo(e, b):
        print(f"{e.state.epoch}-{e.state.iteration}: {b}")
        counter[0] += 1

    engine = DeterministicEngine(foo)
    engine.run(data_loader, epoch_length=10, max_epochs=5)

    assert counter[0] == 50
Example #9
def test_concepts_snippet_warning():
    def random_train_data_generator():
        while True:
            yield torch.randint(0, 100, size=(1, ))

    def print_train_data(engine, batch):
        i = engine.state.iteration
        e = engine.state.epoch
        print("train", e, i, batch.tolist())

    trainer = DeterministicEngine(print_train_data)

    @trainer.on(Events.ITERATION_COMPLETED(every=3))
    def user_handler(_):
        # handler synchronizes the random state
        torch.manual_seed(12)
        a = torch.rand(1)
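        # re-seeding the global RNG inside a handler is presumably what triggers
        # the warning this snippet is named after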

    trainer.run(random_train_data_generator(), max_epochs=3, epoch_length=5)
Example #10
def test_run_finite_iterator_no_epoch_length():
    # FR: https://github.com/pytorch/ignite/issues/871
    unknown_size = 11

    def finite_unk_size_data_iter():
        for i in range(unknown_size):
            yield i

    bc = BatchChecker(data=list(range(unknown_size)))

    engine = DeterministicEngine(lambda e, b: bc.check(b))

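    # the iterator is finite with unknown size; restart it on exhaustion so the
    # run can continue for the requested number of epochs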
    @engine.on(Events.DATALOADER_STOP_ITERATION)
    def restart_iter():
        engine.state.dataloader = finite_unk_size_data_iter()

    data_iter = finite_unk_size_data_iter()
    engine.run(data_iter, max_epochs=5)

    assert engine.state.epoch == 5
    assert engine.state.iteration == unknown_size * 5
Example #11
    def _train(save_iter=None, save_epoch=None, sd=None):
        w_norms = []
        grad_norms = []
        data = []
        chkpt = []

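        # with_dropout, debug, dirname, data_size, batch_size, device and
        # random_train_data_loader are all provided by the enclosing test scope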
        manual_seed(12)
        arch = [
            nn.Conv2d(3, 10, 3),
            nn.ReLU(),
            nn.Conv2d(10, 10, 3),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(10, 5),
            nn.ReLU(),
            nn.Linear(5, 2),
        ]
        if with_dropout:
            arch.insert(2, nn.Dropout2d())
            arch.insert(-2, nn.Dropout())

        model = nn.Sequential(*arch).to(device)
        opt = SGD(model.parameters(), lr=0.001)

        def proc_fn(e, b):
            from ignite.engine.deterministic import _get_rng_states, _repr_rng_state

            s = _repr_rng_state(_get_rng_states())
            model.train()
            opt.zero_grad()
            y = model(b.to(device))
            y.sum().backward()
            opt.step()
            if debug:
                print(trainer.state.iteration, trainer.state.epoch,
                      "proc_fn - b.shape", b.shape,
                      torch.norm(y).item(), s)

        trainer = DeterministicEngine(proc_fn)

        # exactly one of save_iter / save_epoch is expected to be provided
        if save_iter is not None:
            ev = Events.ITERATION_COMPLETED(once=save_iter)
        elif save_epoch is not None:
            ev = Events.EPOCH_COMPLETED(once=save_epoch)
            save_iter = save_epoch * (data_size // batch_size)

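        # checkpoint model, optimizer and trainer state once, at the chosen event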
        @trainer.on(ev)
        def save_chkpt(_):
            if debug:
                print(trainer.state.iteration, "save_chkpt")
            fp = dirname / "test.pt"
            from ignite.engine.deterministic import _repr_rng_state

            tsd = trainer.state_dict()
            if debug:
                print("->", _repr_rng_state(tsd["rng_states"]))
            torch.save([model.state_dict(), opt.state_dict(), tsd], fp)
            chkpt.append(fp)

        def log_event_filter(_, event):
            # log only the first five iterations right after the checkpoint
            return event // save_iter == 1 and 1 <= event % save_iter <= 5

        @trainer.on(Events.ITERATION_COMPLETED(event_filter=log_event_filter))
        def write_data_grads_weights(e):
            x = e.state.batch
            i = e.state.iteration
            data.append([i, x.mean().item(), x.std().item()])

            total = [0.0, 0.0]
            out1 = []
            out2 = []
            for p in model.parameters():
                n1 = torch.norm(p).item()
                n2 = torch.norm(p.grad).item()
                out1.append(n1)
                out2.append(n2)
                total[0] += n1
                total[1] += n2
            w_norms.append([i, total[0]] + out1)
            grad_norms.append([i, total[1]] + out2)

        if sd is not None:
            sd = torch.load(sd)
            model.load_state_dict(sd[0])
            opt.load_state_dict(sd[1])
            from ignite.engine.deterministic import _repr_rng_state

            if debug:
                print("-->", _repr_rng_state(sd[2]["rng_states"]))
            trainer.load_state_dict(sd[2])

        manual_seed(32)
        trainer.run(random_train_data_loader(size=data_size), max_epochs=5)
        return {
            "sd": chkpt,
            "data": data,
            "grads": grad_norms,
            "weights": w_norms
        }
Example #12
    def _test(epoch_length=None):
        max_epochs = 3
        total_batch_size = 4
        num_iters = 17
        torch.manual_seed(0)
        data = torch.randint(0, 1000, size=(num_iters * total_batch_size, ))

        if epoch_length is None:
            epoch_length = num_iters

        for resume_iteration in range(
                2, min(num_iters * max_epochs, epoch_length * max_epochs), 13):

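            # check the resume both with and without DataLoader worker processes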
            for num_workers in [0, 2]:

                sampler, batch_size = _setup_sampler(sampler_type, num_iters,
                                                     total_batch_size)
                orig_dataloader = DataLoader(
                    data,
                    batch_size=batch_size,
                    num_workers=num_workers,
                    pin_memory="cuda" in torch.device(device).type,
                    sampler=sampler,
                    drop_last=True,
                    shuffle=sampler is None,
                )
                seen_batches = []

                def update_fn(_, batch):
                    batch_to_device = batch.to(device)  # exercise device transfer; value intentionally unused
                    seen_batches.append(batch)

                engine = DeterministicEngine(update_fn)

                if sampler_type == "distributed":

                    @engine.on(Events.EPOCH_STARTED)
                    def _(engine):
                        sampler.set_epoch(engine.state.epoch)

                torch.manual_seed(12)
                engine.run(orig_dataloader,
                           max_epochs=max_epochs,
                           epoch_length=epoch_length)

                batch_checker = BatchChecker(seen_batches, init_counter=resume_iteration)

                sampler, batch_size = _setup_sampler(sampler_type, num_iters,
                                                     total_batch_size)
                resume_dataloader = DataLoader(
                    data,
                    batch_size=batch_size,
                    num_workers=num_workers,
                    pin_memory="cuda" in torch.device(device).type,
                    sampler=sampler,
                    drop_last=True,
                    shuffle=sampler is None,
                )

                def update_fn(_, batch):
                    batch_to_device = batch.to(device)
                    cfg_msg = f"{num_workers} {resume_iteration}"
                    assert batch_checker.check(
                        batch
                    ), f"{cfg_msg} | {batch_checker.counter}: {batch_checker.true_batch} vs {batch}"

                engine = DeterministicEngine(update_fn)

                if sampler_type == "distributed":

                    @engine.on(Events.EPOCH_STARTED)
                    def _(engine):
                        sampler.set_epoch(engine.state.epoch)

                resume_state_dict = dict(iteration=resume_iteration,
                                         max_epochs=max_epochs,
                                         epoch_length=epoch_length,
                                         rng_states=None)
                engine.load_state_dict(resume_state_dict)
                torch.manual_seed(12)
                engine.run(resume_dataloader)
                assert engine.state.epoch == max_epochs
                assert (
                    engine.state.iteration == epoch_length * max_epochs
                ), f"{num_workers}, {resume_iteration} | {engine.state.iteration} vs {epoch_length * max_epochs}"
Example #13
    def _test(epoch_length=None):

        max_epochs = 5
        total_batch_size = 4
        num_iters = 21
        torch.manual_seed(0)
        data = torch.randint(0, 1000, size=(num_iters * total_batch_size, ))

        if epoch_length is None:
            epoch_length = num_iters

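        # as above, but resume from every other epoch boundary and also exercise
        # a multi-worker DataLoader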
        for resume_epoch in range(1, max_epochs, 2):

            for num_workers in [0, 4]:
                sampler, batch_size = _setup_sampler(sampler_type, num_iters,
                                                     total_batch_size)

                orig_dataloader = DataLoader(
                    data,
                    batch_size=batch_size,
                    num_workers=num_workers,
                    pin_memory="cuda" in device,
                    sampler=sampler,
                    drop_last=True,
                    shuffle=sampler is None,
                )

                seen_batches = []

                def update_fn(_, batch):
                    batch_to_device = batch.to(device)  # exercise device transfer; value intentionally unused
                    seen_batches.append(batch)

                engine = DeterministicEngine(update_fn)

                if sampler_type == "distributed":

                    @engine.on(Events.EPOCH_STARTED)
                    def _(engine):
                        sampler.set_epoch(engine.state.epoch - 1)

                torch.manual_seed(87)
                engine.run(
                    orig_dataloader,
                    max_epochs=max_epochs,
                    epoch_length=epoch_length,
                )

                batch_checker = BatchChecker(seen_batches,
                                             init_counter=resume_epoch * epoch_length)

                sampler, batch_size = _setup_sampler(sampler_type, num_iters,
                                                     total_batch_size)
                resume_dataloader = DataLoader(
                    data,
                    batch_size=batch_size,
                    num_workers=num_workers,
                    pin_memory="cuda" in device,
                    sampler=sampler,
                    drop_last=True,
                    shuffle=sampler is None,
                )

                def update_fn(_, batch):
                    batch_to_device = batch.to(device)
                    assert batch_checker.check(
                        batch
                    ), f"{num_workers} {resume_epoch} | {batch_checker.counter}: {batch_checker.true_batch} vs {batch}"

                engine = DeterministicEngine(update_fn)

                if sampler_type == "distributed":

                    @engine.on(Events.EPOCH_STARTED)
                    def _(engine):
                        sampler.set_epoch(engine.state.epoch - 1)

                resume_state_dict = dict(epoch=resume_epoch,
                                         max_epochs=max_epochs,
                                         epoch_length=epoch_length,
                                         rng_states=None)
                engine.load_state_dict(resume_state_dict)
                torch.manual_seed(87)
                engine.run(resume_dataloader)
                assert engine.state.epoch == max_epochs
                assert engine.state.iteration == epoch_length * max_epochs