Example #1
0
    def _test(epoch_length=None):
        """Resume a ``DeterministicEngine`` from several iteration offsets.

        A fixed random tensor of 21 integers is used as data. For every
        resume iteration in ``range(2, total, 4)`` a fresh engine is loaded
        with a state dict carrying that iteration and then run over the data.
        Each batch handed to the update function must pass
        ``BatchChecker.check``; the checker's counter is initialised to the
        resume iteration.

        Args:
            epoch_length: iterations per epoch; ``None`` means ``num_iters``
                (21).
        """
        max_epochs = 5
        num_iters = 21
        torch.manual_seed(0)
        data = torch.randint(0, 1000, size=(num_iters,))

        steps_per_epoch = num_iters if epoch_length is None else epoch_length
        total_iterations = steps_per_epoch * max_epochs
        resume_stop = min(num_iters * max_epochs, total_iterations)

        for resume_iteration in range(2, resume_stop, 4):
            batch_checker = BatchChecker(data, init_counter=resume_iteration)

            def update_fn(_, batch):
                batch_ok = batch_checker.check(batch)
                assert batch_ok, f"{resume_iteration} | {batch_checker.counter}: {batch_checker.true_batch} vs {batch}"

            engine = DeterministicEngine(update_fn)

            # At every epoch boundary the engine's iteration count and the
            # checker's counter are required to be equal.
            @engine.on(Events.EPOCH_COMPLETED)
            def check_iteration(_):
                assert engine.state.iteration == batch_checker.counter

            engine.load_state_dict({
                "iteration": resume_iteration,
                "max_epochs": max_epochs,
                "epoch_length": steps_per_epoch,
                "rng_states": None,
            })
            engine.run(data)

            assert engine.state.epoch == max_epochs
            assert engine.state.iteration == total_iterations
Example #2
0
    def _test(epoch_length=None):
        """Resume a ``DeterministicEngine`` from each epoch boundary.

        A fixed random tensor of 21 integers is used as data. For every
        ``resume_epoch`` in ``1 .. max_epochs - 1`` a fresh engine is loaded
        with a state dict carrying that epoch and then run over the data.
        Each batch must pass ``BatchChecker.check``; the checker's counter is
        initialised to ``resume_epoch * epoch_length``.

        Args:
            epoch_length: iterations per epoch; ``None`` means ``num_iters``
                (21).
        """
        max_epochs = 10
        num_iters = 21
        torch.manual_seed(0)
        data = torch.randint(0, 1000, size=(num_iters,))

        steps_per_epoch = num_iters if epoch_length is None else epoch_length
        total_iterations = steps_per_epoch * max_epochs

        for resume_epoch in range(1, max_epochs):
            start_counter = resume_epoch * steps_per_epoch
            batch_checker = BatchChecker(data, init_counter=start_counter)

            def update_fn(_, batch):
                batch_ok = batch_checker.check(batch)
                assert batch_ok, f"{resume_epoch} | {batch_checker.counter}: {batch_checker.true_batch} vs {batch}"

            engine = DeterministicEngine(update_fn)
            engine.load_state_dict({
                "epoch": resume_epoch,
                "max_epochs": max_epochs,
                "epoch_length": steps_per_epoch,
                "rng_states": None,
            })
            engine.run(data)

            assert engine.state.epoch == max_epochs
            assert engine.state.iteration == total_iterations
Example #3
0
    def _test(epoch_length=None):
        """Resume from iteration offsets when the data source is an endless generator.

        For every resume iteration in ``range(1, total, 7)``:

        1. a reference run (seeded with ``torch.manual_seed(24)``) records
           every batch the engine sees;
        2. a second engine is loaded with a state dict carrying the resume
           iteration, the seed is reset to 24, and each batch of the second
           run must pass ``BatchChecker.check`` against the recorded batches,
           with the checker's counter initialised to the resume iteration.

        ``device`` comes from the enclosing scope, which is not visible here.

        Args:
            epoch_length: iterations per epoch; ``None`` means ``num_iters``
                (17).
        """
        max_epochs = 3
        batch_size = 4
        num_iters = 17

        def infinite_data_iterator():
            # Never-ending stream of random integer batches of ``batch_size``.
            while True:
                for _ in range(num_iters):
                    yield torch.randint(0, 1000, size=(batch_size,), device=device)

        if epoch_length is None:
            epoch_length = num_iters

        resume_stop = min(num_iters * max_epochs, epoch_length * max_epochs)

        for resume_iteration in range(1, resume_stop, 7):
            seen_batches = []

            def record_batch(_, batch):
                seen_batches.append(batch)

            # Reference pass: collect the full batch sequence.
            reference_engine = DeterministicEngine(record_batch)
            torch.manual_seed(24)
            reference_engine.run(
                infinite_data_iterator(),
                max_epochs=max_epochs,
                epoch_length=epoch_length,
            )

            batch_checker = BatchChecker(seen_batches, init_counter=resume_iteration)

            def update_fn(_, batch):
                batch_ok = batch_checker.check(batch)
                assert batch_ok, f"{resume_iteration} | {batch_checker.counter}: {batch_checker.true_batch} vs {batch}"

            # Resumed pass: same seed, state restored to ``resume_iteration``.
            engine = DeterministicEngine(update_fn)
            engine.load_state_dict({
                "iteration": resume_iteration,
                "max_epochs": max_epochs,
                "epoch_length": epoch_length,
                "rng_states": None,
            })
            torch.manual_seed(24)
            engine.run(infinite_data_iterator())

            assert engine.state.epoch == max_epochs
            iteration_ok = engine.state.iteration == epoch_length * max_epochs
            assert iteration_ok, f"{resume_iteration} | {engine.state.iteration} vs {epoch_length * max_epochs}"
Example #4
0
def test_run_finite_iterator_no_epoch_length():
    """Run 5 epochs over a finite generator without passing ``epoch_length``.

    A handler on ``Events.DATALOADER_STOP_ITERATION`` replaces
    ``engine.state.dataloader`` with a fresh generator. The test expects
    5 epochs and ``unknown_size * 5`` iterations, and every batch must pass
    ``BatchChecker.check`` against ``list(range(unknown_size))``.
    """
    # FR: https://github.com/pytorch/ignite/issues/871
    unknown_size = 11

    def finite_unk_size_data_iter():
        yield from range(unknown_size)

    bc = BatchChecker(data=list(range(unknown_size)))

    def update_fn(_, batch):
        # The check result must be asserted; returning it from the update
        # function would silently ignore a mismatching batch.
        assert bc.check(batch), f"{bc.counter}: {bc.true_batch} vs {batch}"

    engine = DeterministicEngine(update_fn)

    @engine.on(Events.DATALOADER_STOP_ITERATION)
    def restart_iter():
        engine.state.dataloader = finite_unk_size_data_iter()

    data_iter = finite_unk_size_data_iter()
    engine.run(data_iter, max_epochs=5)

    assert engine.state.epoch == 5
    assert engine.state.iteration == unknown_size * 5
Example #5
0
def test_run_finite_iterator_no_epoch_length_2():
    """Run 5 epochs over a finite generator, refreshing it every ``known_size`` iterations.

    A handler on ``Events.ITERATION_COMPLETED(every=known_size)`` replaces
    ``engine.state.dataloader`` with a fresh generator. The test expects
    5 epochs and ``known_size * 5`` iterations, and every batch must pass
    ``BatchChecker.check`` against ``list(range(known_size))``.
    """
    # FR: https://github.com/pytorch/ignite/issues/871
    known_size = 11

    def finite_size_data_iter(size):
        yield from range(size)

    bc = BatchChecker(data=list(range(known_size)))

    def update_fn(_, batch):
        # The check result must be asserted; returning it from the update
        # function would silently ignore a mismatching batch.
        assert bc.check(batch), f"{bc.counter}: {bc.true_batch} vs {batch}"

    engine = Engine(update_fn)

    @engine.on(Events.ITERATION_COMPLETED(every=known_size))
    def restart_iter():
        engine.state.dataloader = finite_size_data_iter(known_size)

    data_iter = finite_size_data_iter(known_size)
    engine.run(data_iter, max_epochs=5)

    assert engine.state.epoch == 5
    assert engine.state.iteration == known_size * 5
Example #6
0
    def _test(epoch_length=None):
        """Resume from epoch boundaries when the data source is an endless generator.

        For every ``resume_epoch`` in ``1 .. max_epochs - 1``:

        1. a reference run (seeded with ``torch.manual_seed(121)``) records
           every batch the engine sees;
        2. a second engine is loaded with a state dict carrying the resume
           epoch, the seed is reset to 121, and each batch of the second run
           must pass ``BatchChecker.check`` against the recorded batches, with
           the checker's counter initialised to ``resume_epoch * epoch_length``.

        ``device`` comes from the enclosing scope, which is not visible here.

        Args:
            epoch_length: iterations per epoch; ``None`` means ``num_iters``
                (21).
        """
        max_epochs = 5
        batch_size = 4
        num_iters = 21

        def infinite_data_iterator():
            # Never-ending stream of random integer batches of ``batch_size``.
            while True:
                for _ in range(num_iters):
                    yield torch.randint(0, 1000, size=(batch_size,), device=device)

        steps_per_epoch = num_iters if epoch_length is None else epoch_length
        total_iterations = steps_per_epoch * max_epochs

        for resume_epoch in range(1, max_epochs):
            seen_batches = []

            def record_batch(_, batch):
                # NOTE: a random op executed here while consuming the batch
                # (e.g. torch.rand(1)) would prevent resuming correctly.
                seen_batches.append(batch)

            # Reference pass: collect the full batch sequence.
            reference_engine = DeterministicEngine(record_batch)
            torch.manual_seed(121)
            reference_engine.run(
                infinite_data_iterator(),
                max_epochs=max_epochs,
                epoch_length=steps_per_epoch,
            )

            start_counter = resume_epoch * steps_per_epoch
            batch_checker = BatchChecker(seen_batches, init_counter=start_counter)

            def update_fn(_, batch):
                batch_ok = batch_checker.check(batch)
                assert batch_ok, f"{resume_epoch} | {batch_checker.counter}: {batch_checker.true_batch} vs {batch}"

            # Resumed pass: same seed, state restored to ``resume_epoch``.
            engine = DeterministicEngine(update_fn)
            engine.load_state_dict({
                "epoch": resume_epoch,
                "max_epochs": max_epochs,
                "epoch_length": steps_per_epoch,
                "rng_states": None,
            })
            torch.manual_seed(121)
            engine.run(infinite_data_iterator())

            assert engine.state.epoch == max_epochs
            assert engine.state.iteration == total_iterations
Example #7
0
def test_run_once_finite_iterator_no_epoch_length():
    """Run an engine once over a finite generator without ``epoch_length`` or ``max_epochs``.

    The test expects a single epoch of ``unknown_size`` iterations, exactly
    one ``Events.COMPLETED`` handler call, and every batch to pass
    ``BatchChecker.check`` against ``list(range(unknown_size))``.
    """
    # FR: https://github.com/pytorch/ignite/issues/871

    unknown_size = 11

    def finite_unk_size_data_iter():
        yield from range(unknown_size)

    bc = BatchChecker(data=list(range(unknown_size)))

    def update_fn(_, batch):
        # The check result must be asserted; returning it from the update
        # function would silently ignore a mismatching batch.
        assert bc.check(batch), f"{bc.counter}: {bc.true_batch} vs {batch}"

    engine = Engine(update_fn)

    completed_handler = MagicMock()
    engine.add_event_handler(Events.COMPLETED, completed_handler)

    data_iter = finite_unk_size_data_iter()
    engine.run(data_iter)

    assert engine.state.epoch == 1
    assert engine.state.iteration == unknown_size
    assert completed_handler.call_count == 1
Example #8
0
    def _test_as_iter(data, max_epochs, num_iters):
        """Run an ``Engine`` over ``iter(data)`` and verify batches and final counters.

        Args:
            data: sized collection of batches; wrapped in ``iter`` before
                being handed to the engine.
            max_epochs: number of epochs to run.
            num_iters: value passed as ``epoch_length``; when ``None`` the
                expected per-epoch iteration count falls back to ``len(data)``.
        """
        batch_checker = BatchChecker(data)

        def update_fn(_, batch):
            batch_ok = batch_checker.check(batch)
            assert batch_ok, "{}: {} vs {}".format(
                batch_checker.counter, batch_checker.true_batch, batch)

        engine = Engine(update_fn)
        engine.run(iter(data), max_epochs=max_epochs, epoch_length=num_iters)

        steps_per_epoch = len(data) if num_iters is None else num_iters
        assert engine.state.iteration == steps_per_epoch * max_epochs
        assert engine.state.epoch == max_epochs
Example #9
0
    def _test(data, max_epochs, num_iters):
        """Run an ``Engine`` directly over ``data`` and verify batches and final counters.

        Args:
            data: sized collection of batches passed to ``engine.run`` as is.
            max_epochs: number of epochs to run.
            num_iters: value passed as ``epoch_length``; when ``None`` the
                expected per-epoch iteration count falls back to ``len(data)``.
        """
        batch_checker = BatchChecker(data)

        def update_fn(_, batch):
            batch_ok = batch_checker.check(batch)
            assert batch_ok, f"{batch_checker.counter}: {batch_checker.true_batch} vs {batch}"

        engine = Engine(update_fn)
        engine.run(data, max_epochs=max_epochs, epoch_length=num_iters)

        steps_per_epoch = len(data) if num_iters is None else num_iters
        assert engine.state.iteration == steps_per_epoch * max_epochs
        assert engine.state.epoch == max_epochs
Example #10
0
    def _test(epoch_length=None):
        """Resume from iteration offsets with a ``DataLoader`` as the data source.

        For every resume iteration in ``range(2, total, 13)`` and for
        ``num_workers`` in ``(0, 2)``:

        1. a reference run (seeded with ``torch.manual_seed(12)``) records
           every batch the engine sees;
        2. a second engine is loaded with a state dict carrying the resume
           iteration and run on a freshly built loader after resetting the
           seed to 12; each batch must pass ``BatchChecker.check`` against the
           recorded batches, with the checker's counter initialised to the
           resume iteration.

        ``device``, ``sampler_type`` and ``_setup_sampler`` come from outside
        this function and are not visible here.

        Args:
            epoch_length: iterations per epoch; ``None`` means ``num_iters``
                (17).
        """
        max_epochs = 3
        total_batch_size = 4
        num_iters = 17
        torch.manual_seed(0)
        data = torch.randint(0, 1000, size=(num_iters * total_batch_size,))

        if epoch_length is None:
            epoch_length = num_iters

        def make_loader(workers):
            # Build a new sampler together with the loader that uses it.
            new_sampler, batch_size = _setup_sampler(sampler_type, num_iters,
                                                     total_batch_size)
            loader = DataLoader(
                data,
                batch_size=batch_size,
                num_workers=workers,
                pin_memory="cuda" in torch.device(device).type,
                sampler=new_sampler,
                drop_last=True,
                shuffle=new_sampler is None,
            )
            return new_sampler, loader

        def sync_sampler_epoch(target_engine, target_sampler):
            # Only the "distributed" sampler type gets its epoch set from the
            # engine at the start of each epoch.
            if sampler_type != "distributed":
                return

            @target_engine.on(Events.EPOCH_STARTED)
            def _(engine):
                target_sampler.set_epoch(engine.state.epoch)

        resume_stop = min(num_iters * max_epochs, epoch_length * max_epochs)

        for resume_iteration in range(2, resume_stop, 13):
            for num_workers in [0, 2]:
                # Reference pass: collect the full batch sequence.
                orig_sampler, orig_dataloader = make_loader(num_workers)
                seen_batches = []

                def record_batch(_, batch):
                    batch.to(device)  # the moved copy is intentionally unused
                    seen_batches.append(batch)

                reference_engine = DeterministicEngine(record_batch)
                sync_sampler_epoch(reference_engine, orig_sampler)

                torch.manual_seed(12)
                reference_engine.run(orig_dataloader,
                                     max_epochs=max_epochs,
                                     epoch_length=epoch_length)

                batch_checker = BatchChecker(seen_batches,
                                             init_counter=resume_iteration)

                # Resumed pass: fresh loader, same seed, restored state.
                resume_sampler, resume_dataloader = make_loader(num_workers)

                def update_fn(_, batch):
                    batch.to(device)  # the moved copy is intentionally unused
                    cfg_msg = f"{num_workers} {resume_iteration}"
                    batch_ok = batch_checker.check(batch)
                    assert batch_ok, f"{cfg_msg} | {batch_checker.counter}: {batch_checker.true_batch} vs {batch}"

                engine = DeterministicEngine(update_fn)
                sync_sampler_epoch(engine, resume_sampler)

                engine.load_state_dict({
                    "iteration": resume_iteration,
                    "max_epochs": max_epochs,
                    "epoch_length": epoch_length,
                    "rng_states": None,
                })
                torch.manual_seed(12)
                engine.run(resume_dataloader)

                assert engine.state.epoch == max_epochs
                iteration_ok = engine.state.iteration == epoch_length * max_epochs
                assert iteration_ok, f"{num_workers}, {resume_iteration} | {engine.state.iteration} vs {epoch_length * max_epochs}"
Example #11
0
    def _test(epoch_length=None):
        """Resume from epoch boundaries with a ``DataLoader`` as the data source.

        For every ``resume_epoch`` in ``range(1, max_epochs, 2)`` and for
        ``num_workers`` in ``(0, 4)``:

        1. a reference run (seeded with ``torch.manual_seed(87)``) records
           every batch the engine sees;
        2. a second engine is loaded with a state dict carrying the resume
           epoch and run on a freshly built loader after resetting the seed
           to 87; each batch must pass ``BatchChecker.check`` against the
           recorded batches, with the checker's counter initialised to
           ``resume_epoch * epoch_length``.

        ``device``, ``sampler_type`` and ``_setup_sampler`` come from outside
        this function and are not visible here.

        Args:
            epoch_length: iterations per epoch; ``None`` means ``num_iters``
                (21).
        """
        max_epochs = 5
        total_batch_size = 4
        num_iters = 21
        torch.manual_seed(0)
        data = torch.randint(0, 1000, size=(num_iters * total_batch_size, ))

        if epoch_length is None:
            epoch_length = num_iters

        # Normalise through torch.device so that both a string such as
        # "cuda:0" and a torch.device object are accepted; a plain
        # `"cuda" in device` raises TypeError for torch.device objects.
        pin_memory = "cuda" in torch.device(device).type

        for resume_epoch in range(1, max_epochs, 2):

            for num_workers in [0, 4]:
                sampler, batch_size = _setup_sampler(sampler_type, num_iters,
                                                     total_batch_size)

                orig_dataloader = DataLoader(
                    data,
                    batch_size=batch_size,
                    num_workers=num_workers,
                    pin_memory=pin_memory,
                    sampler=sampler,
                    drop_last=True,
                    shuffle=sampler is None,
                )

                seen_batchs = []

                def update_fn(_, batch):
                    batch.to(device)  # the moved copy is intentionally unused
                    seen_batchs.append(batch)

                # Reference pass: collect the full batch sequence.
                engine = DeterministicEngine(update_fn)

                if sampler_type == "distributed":

                    @engine.on(Events.EPOCH_STARTED)
                    def _(engine):
                        sampler.set_epoch(engine.state.epoch - 1)

                torch.manual_seed(87)
                engine.run(
                    orig_dataloader,
                    max_epochs=max_epochs,
                    epoch_length=epoch_length,
                )

                batch_checker = BatchChecker(seen_batchs,
                                             init_counter=resume_epoch *
                                             epoch_length)

                sampler, batch_size = _setup_sampler(sampler_type, num_iters,
                                                     total_batch_size)
                resume_dataloader = DataLoader(
                    data,
                    batch_size=batch_size,
                    num_workers=num_workers,
                    pin_memory=pin_memory,
                    sampler=sampler,
                    drop_last=True,
                    shuffle=sampler is None,
                )

                def update_fn(_, batch):
                    batch.to(device)  # the moved copy is intentionally unused
                    assert batch_checker.check(
                        batch
                    ), f"{num_workers} {resume_epoch} | {batch_checker.counter}: {batch_checker.true_batch} vs {batch}"

                # Resumed pass: fresh loader, same seed, restored state.
                engine = DeterministicEngine(update_fn)

                if sampler_type == "distributed":

                    @engine.on(Events.EPOCH_STARTED)
                    def _(engine):
                        sampler.set_epoch(engine.state.epoch - 1)

                resume_state_dict = dict(epoch=resume_epoch,
                                         max_epochs=max_epochs,
                                         epoch_length=epoch_length,
                                         rng_states=None)
                engine.load_state_dict(resume_state_dict)
                torch.manual_seed(87)
                engine.run(resume_dataloader)
                assert engine.state.epoch == max_epochs
                assert engine.state.iteration == epoch_length * max_epochs