Example #1
def test_state_dict():
    engine = DeterministicEngine(lambda e, b: 1)
    sd = engine.state_dict()
    assert isinstance(sd, Mapping) and len(sd) == 4
    assert "iteration" in sd and sd["iteration"] == 0
    assert "max_epochs" in sd and sd["max_epochs"] is None
    assert "epoch_length" in sd and sd["epoch_length"] is None
    assert "rng_states" in sd and sd["rng_states"] is not None
Example #2
def test_engine_no_data_asserts():
    trainer = DeterministicEngine(lambda e, b: None)

    with pytest.raises(
        ValueError, match=r"Deterministic engine does not support the option of data=None"
    ):
        trainer.run(max_epochs=10, epoch_length=10)
Example #3
def test_dataloader_no_dataset_kind():
    # tests issue: https://github.com/pytorch/ignite/issues/1022

    engine = DeterministicEngine(lambda e, b: None)

    data = torch.randint(0, 1000, size=(100 * 4, ))
    dataloader = DataLoader(data, batch_size=4)
    # OldDataLoader is a helper defined in the test module that wraps the loader
    # to emulate old-style DataLoaders lacking the `_dataset_kind` attribute
    dataloader = OldDataLoader(dataloader)

    engine.run(dataloader)
Example #4
    def _test(epoch_length=None):
        max_epochs = 3
        batch_size = 4
        num_iters = 17

        def infinite_data_iterator():
            while True:
                for _ in range(num_iters):
                    data = torch.randint(0, 1000, size=(batch_size,), device=device)
                    yield data

        if epoch_length is None:
            epoch_length = num_iters

        for resume_iteration in range(
                1, min(num_iters * max_epochs, epoch_length * max_epochs), 7):

            seen_batchs = []

            def update_fn(_, batch):
                seen_batchs.append(batch)

            engine = DeterministicEngine(update_fn)

            torch.manual_seed(24)
            engine.run(
                infinite_data_iterator(),
                max_epochs=max_epochs,
                epoch_length=epoch_length,
            )

            # BatchChecker is a test-suite helper that replays the recorded
            # batches and compares them with the ones seen after resuming
            batch_checker = BatchChecker(seen_batchs, init_counter=resume_iteration)

            def update_fn(_, batch):
                assert batch_checker.check(
                    batch
                ), f"{resume_iteration} | {batch_checker.counter}: {batch_checker.true_batch} vs {batch}"

            engine = DeterministicEngine(update_fn)

            resume_state_dict = dict(iteration=resume_iteration,
                                     max_epochs=max_epochs,
                                     epoch_length=epoch_length,
                                     rng_states=None)
            engine.load_state_dict(resume_state_dict)
            torch.manual_seed(24)
            engine.run(infinite_data_iterator())
            assert engine.state.epoch == max_epochs
            assert (
                engine.state.iteration == epoch_length * max_epochs
            ), f"{resume_iteration} | {engine.state.iteration} vs {epoch_length * max_epochs}"
Example #5
    def _test(epoch_length=None):
        max_epochs = 5
        batch_size = 4
        num_iters = 21

        def infinite_data_iterator():
            while True:
                for _ in range(num_iters):
                    data = torch.randint(0, 1000, size=(batch_size,), device=device)
                    yield data

        if epoch_length is None:
            epoch_length = num_iters

        for resume_epoch in range(1, max_epochs):
            seen_batchs = []

            def update_fn(_, batch):
                # if a random op were applied while consuming the data batch, we could not resume correctly, e.g.:
                # torch.rand(1)
                seen_batchs.append(batch)

            engine = DeterministicEngine(update_fn)
            torch.manual_seed(121)
            engine.run(
                infinite_data_iterator(),
                max_epochs=max_epochs,
                epoch_length=epoch_length,
            )

            batch_checker = BatchChecker(
                seen_batchs, init_counter=resume_epoch * epoch_length
            )

            def update_fn(_, batch):
                assert batch_checker.check(
                    batch
                ), f"{resume_epoch} | {batch_checker.counter}: {batch_checker.true_batch} vs {batch}"

            engine = DeterministicEngine(update_fn)

            resume_state_dict = dict(epoch=resume_epoch,
                                     max_epochs=max_epochs,
                                     epoch_length=epoch_length,
                                     rng_states=None)
            engine.load_state_dict(resume_state_dict)
            torch.manual_seed(121)
            engine.run(infinite_data_iterator())
            assert engine.state.epoch == max_epochs
            assert engine.state.iteration == epoch_length * max_epochs
Example #6
def test_engine_with_dataloader_no_auto_batching():
    # tests https://github.com/pytorch/ignite/issues/941

    data = torch.rand(64, 4, 10)
    data_loader = DataLoader(
        data, batch_size=None, sampler=BatchSampler(RandomSampler(data), batch_size=8, drop_last=True)
    )

    counter = [0]

    def foo(e, b):
        print(f"{e.state.epoch}-{e.state.iteration}: {b}")
        counter[0] += 1

    engine = DeterministicEngine(foo)
    engine.run(data_loader, epoch_length=10, max_epochs=5)

    assert counter[0] == 50
Example #7
def test_concepts_snippet_warning():
    def random_train_data_generator():
        while True:
            yield torch.randint(0, 100, size=(1, ))

    def print_train_data(engine, batch):
        i = engine.state.iteration
        e = engine.state.epoch
        print("train", e, i, batch.tolist())

    trainer = DeterministicEngine(print_train_data)

    @trainer.on(Events.ITERATION_COMPLETED(every=3))
    def user_handler(_):
        # handler synchronizes the random state
        torch.manual_seed(12)
        torch.rand(1)

    trainer.run(random_train_data_generator(), max_epochs=3, epoch_length=5)
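The handler above perturbs the global RNG between iterations, which is what this snippet's warning is about. If a handler needs random numbers without disturbing the data stream, `ignite.engine.deterministic` provides a `keep_random_state` decorator that saves and restores the global RNG states around the call; a minimal sketch reusing the trainer above:

from ignite.engine.deterministic import keep_random_state

@trainer.on(Events.ITERATION_COMPLETED(every=3))
@keep_random_state
def rng_preserving_handler(_):
    # RNG states are snapshotted before this call and restored afterwards
    torch.manual_seed(12)
    torch.rand(1)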
Example #8
def test_run_finite_iterator_no_epoch_length():
    # FR: https://github.com/pytorch/ignite/issues/871
    unknown_size = 11

    def finite_unk_size_data_iter():
        for i in range(unknown_size):
            yield i

    bc = BatchChecker(data=list(range(unknown_size)))

    engine = DeterministicEngine(lambda e, b: bc.check(b))

    @engine.on(Events.DATALOADER_STOP_ITERATION)
    def restart_iter():
        engine.state.dataloader = finite_unk_size_data_iter()

    data_iter = finite_unk_size_data_iter()
    engine.run(data_iter, max_epochs=5)

    assert engine.state.epoch == 5
    assert engine.state.iteration == unknown_size * 5
Example #9
    def _test(epoch_length=None):
        max_epochs = 10
        num_iters = 21
        torch.manual_seed(0)
        data = torch.randint(0, 1000, size=(num_iters, ))
        if epoch_length is None:
            epoch_length = num_iters

        for resume_epoch in range(1, max_epochs):
            batch_checker = BatchChecker(
                data, init_counter=resume_epoch * epoch_length
            )

            def update_fn(_, batch):
                assert batch_checker.check(
                    batch
                ), f"{resume_epoch} | {batch_checker.counter}: {batch_checker.true_batch} vs {batch}"

            engine = DeterministicEngine(update_fn)

            resume_state_dict = dict(epoch=resume_epoch,
                                     max_epochs=max_epochs,
                                     epoch_length=epoch_length,
                                     rng_states=None)
            engine.load_state_dict(resume_state_dict)
            engine.run(data)
            assert engine.state.epoch == max_epochs
            assert engine.state.iteration == epoch_length * max_epochs
Example #10
    def _test(epoch_length=None):

        max_epochs = 5
        num_iters = 21
        torch.manual_seed(0)
        data = torch.randint(0, 1000, size=(num_iters, ))
        if epoch_length is None:
            epoch_length = num_iters

        for resume_iteration in range(
                2, min(num_iters * max_epochs, epoch_length * max_epochs), 4):
            batch_checker = BatchChecker(data, init_counter=resume_iteration)

            def update_fn(_, batch):
                assert batch_checker.check(
                    batch
                ), f"{resume_iteration} | {batch_checker.counter}: {batch_checker.true_batch} vs {batch}"

            engine = DeterministicEngine(update_fn)

            @engine.on(Events.EPOCH_COMPLETED)
            def check_iteration(_):
                assert engine.state.iteration == batch_checker.counter

            resume_state_dict = dict(iteration=resume_iteration,
                                     max_epochs=max_epochs,
                                     epoch_length=epoch_length,
                                     rng_states=None)
            engine.load_state_dict(resume_state_dict)
            engine.run(data)
            assert engine.state.epoch == max_epochs
            assert engine.state.iteration == epoch_length * max_epochs
Example #11
def test_engine_with_dataloader_no_auto_batching():
    # tests https://github.com/pytorch/ignite/issues/941
    from torch.utils.data import DataLoader, BatchSampler, RandomSampler

    data = torch.rand(64, 4, 10)
    data_loader = DataLoader(
        data,
        batch_size=None,
        sampler=BatchSampler(RandomSampler(data), batch_size=8, drop_last=True),
    )

    counter = [0]

    def foo(e, b):
        print("{}-{}: {}".format(e.state.epoch, e.state.iteration, b))
        counter[0] += 1

    engine = DeterministicEngine(foo)
    engine.run(data_loader, epoch_length=10, max_epochs=5)

    assert counter[0] == 50
Example #12
def test_concepts_snippet_resume():

    # Commented imports required in the snippet
    # import torch
    # from torch.utils.data import DataLoader

    # from ignite.engine import DeterministicEngine
    # from ignite.utils import manual_seed

    seen_batches = []
    manual_seed(seed=15)

    def random_train_data_loader(size):
        data = torch.arange(0, size)
        return DataLoader(data, batch_size=4, shuffle=True)

    def print_train_data(engine, batch):
        i = engine.state.iteration
        e = engine.state.epoch
        print("train", e, i, batch.tolist())
        seen_batches.append(batch)

    trainer = DeterministicEngine(print_train_data)

    print("Original Run")
    manual_seed(56)
    trainer.run(random_train_data_loader(40), max_epochs=2, epoch_length=5)

    original_batches = list(seen_batches)
    seen_batches = []

    print("Resumed Run")
    trainer.load_state_dict({
        "epoch": 1,
        "epoch_length": 5,
        "max_epochs": 2,
        "rng_states": None
    })
    manual_seed(56)
    trainer.run(random_train_data_loader(40))

    resumed_batches = list(seen_batches)
    seen_batches = []
    for b1, b2 in zip(original_batches[5:], resumed_batches):
        assert (b1 == b2).all()
Example #13
def run(output_path, config):

    distributed = dist.is_available() and dist.is_initialized()
    rank = dist.get_rank() if distributed else 0

    manual_seed(config["seed"] + rank)

    # Setup dataflow, model, optimizer, criterion
    train_loader, test_loader = utils.get_dataflow(config, distributed)
    model, optimizer = utils.get_model_optimizer(config, distributed)
    criterion = nn.CrossEntropyLoss().to(utils.device)

    le = len(train_loader)
    milestones_values = [
        (0, 0.0),
        (le * config["num_warmup_epochs"], config["learning_rate"]),
        (le * config["num_epochs"], 0.0),
    ]
    lr_scheduler = PiecewiseLinear(optimizer, param_name="lr", milestones_values=milestones_values)

    # Setup Ignite trainer:
    # - let's define the training step
    # - add other common handlers:
    #    - TerminateOnNan
    #    - handler to set up learning rate scheduling
    #    - ModelCheckpoint
    #    - `RunningAverage` on `train_step` output
    #    - two progress bars: on epochs and, optionally, on iterations

    def train_step(engine, batch):

        x = convert_tensor(batch[0], device=utils.device, non_blocking=True)
        y = convert_tensor(batch[1], device=utils.device, non_blocking=True)

        model.train()
        # Supervised part
        y_pred = model(x)
        loss = criterion(y_pred, y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        return {
            "batch loss": loss.item(),
        }

    if config["deterministic"] and rank == 0:
        print("Setup deterministic trainer")
    trainer = Engine(train_step) if not config["deterministic"] else DeterministicEngine(train_step)
    train_sampler = train_loader.sampler if distributed else None
    to_save = {"trainer": trainer, "model": model, "optimizer": optimizer, "lr_scheduler": lr_scheduler}
    metric_names = [
        "batch loss",
    ]
    common.setup_common_training_handlers(
        trainer,
        train_sampler=train_sampler,
        to_save=to_save,
        save_every_iters=config["checkpoint_every"],
        output_path=output_path,
        lr_scheduler=lr_scheduler,
        output_names=metric_names,
        with_pbar_on_iters=config["display_iters"],
        log_every_iters=10,
    )

    if rank == 0:
        # Setup Tensorboard logger - wrapper on SummaryWriter
        tb_logger = TensorboardLogger(log_dir=output_path)
        # Attach logger to the trainer and log trainer's metrics (stored in trainer.state.metrics) every iteration
        tb_logger.attach(
            trainer,
            log_handler=OutputHandler(tag="train", metric_names=metric_names),
            event_name=Events.ITERATION_COMPLETED,
        )
        # log optimizer's parameters: "lr" every iteration
        tb_logger.attach(
            trainer, log_handler=OptimizerParamsHandler(optimizer, param_name="lr"), event_name=Events.ITERATION_STARTED
        )

    # Let's now setup evaluator engine to perform model's validation and compute metrics
    metrics = {
        "accuracy": Accuracy(device=utils.device if distributed else None),
        "loss": Loss(criterion, device=utils.device if distributed else None),
    }

    # We define two evaluators, as they won't have exactly the same role:
    # - `evaluator` will save the best model based on validation score
    # - `train_evaluator` will compute metrics on the training set
    evaluator = create_supervised_evaluator(model, metrics=metrics, device=utils.device, non_blocking=True)
    train_evaluator = create_supervised_evaluator(model, metrics=metrics, device=utils.device, non_blocking=True)

    def run_validation(engine):
        train_evaluator.run(train_loader)
        evaluator.run(test_loader)

    trainer.add_event_handler(Events.EPOCH_STARTED(every=config["validate_every"]), run_validation)
    trainer.add_event_handler(Events.COMPLETED, run_validation)

    if rank == 0:
        # Setup progress bar on evaluation engines
        if config["display_iters"]:
            ProgressBar(persist=False, desc="Train evaluation").attach(train_evaluator)
            ProgressBar(persist=False, desc="Test evaluation").attach(evaluator)

        # Let's log metrics of `train_evaluator` stored in `train_evaluator.state.metrics` when validation run is done
        tb_logger.attach(
            train_evaluator,
            log_handler=OutputHandler(
                tag="train", metric_names="all", global_step_transform=global_step_from_engine(trainer)
            ),
            event_name=Events.COMPLETED,
        )

        # Let's log metrics of `evaluator` stored in `evaluator.state.metrics` when validation run is done
        tb_logger.attach(
            evaluator,
            log_handler=OutputHandler(
                tag="test", metric_names="all", global_step_transform=global_step_from_engine(trainer)
            ),
            event_name=Events.COMPLETED,
        )

        # Store 3 best models by validation accuracy:
        common.save_best_model_by_val_score(
            output_path, evaluator, model=model, metric_name="accuracy", n_saved=3, trainer=trainer, tag="test"
        )

        # Optionally log model gradients
        if config["log_model_grads_every"] is not None:
            tb_logger.attach(
                trainer,
                log_handler=GradsHistHandler(model, tag=model.__class__.__name__),
                event_name=Events.ITERATION_COMPLETED(every=config["log_model_grads_every"]),
            )

    # In order to check training resuming we can emulate a crash
    if config["crash_iteration"] is not None:

        @trainer.on(Events.ITERATION_STARTED(once=config["crash_iteration"]))
        def _(engine):
            raise Exception("STOP at iteration: {}".format(engine.state.iteration))

    resume_from = config["resume_from"]
    if resume_from is not None:
        checkpoint_fp = Path(resume_from)
        assert checkpoint_fp.exists(), "Checkpoint '{}' is not found".format(checkpoint_fp.as_posix())
        print("Resume from a checkpoint: {}".format(checkpoint_fp.as_posix()))
        checkpoint = torch.load(checkpoint_fp.as_posix())
        Checkpoint.load_objects(to_load=to_save, checkpoint=checkpoint)

    try:
        trainer.run(train_loader, max_epochs=config["num_epochs"])
    except Exception:
        import traceback

        print(traceback.format_exc())

    if rank == 0:
        tb_logger.close()
Example #14
    def _train(save_iter=None, save_epoch=None, sd=None):
        w_norms = []
        grad_norms = []
        data = []
        chkpt = []

        manual_seed(12)
        arch = [
            nn.Conv2d(3, 10, 3),
            nn.ReLU(),
            nn.Conv2d(10, 10, 3),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(10, 5),
            nn.ReLU(),
            nn.Linear(5, 2),
        ]
        if with_dropout:
            arch.insert(2, nn.Dropout2d())
            arch.insert(-2, nn.Dropout())

        model = nn.Sequential(*arch).to(device)
        opt = SGD(model.parameters(), lr=0.001)

        def proc_fn(e, b):
            from ignite.engine.deterministic import _get_rng_states, _repr_rng_state

            s = _repr_rng_state(_get_rng_states())
            model.train()
            opt.zero_grad()
            y = model(b.to(device))
            y.sum().backward()
            opt.step()
            if debug:
                print(trainer.state.iteration, trainer.state.epoch,
                      "proc_fn - b.shape", b.shape,
                      torch.norm(y).item(), s)

        trainer = DeterministicEngine(proc_fn)

        # exactly one of save_iter / save_epoch is expected
        if save_iter is not None:
            ev = Events.ITERATION_COMPLETED(once=save_iter)
        elif save_epoch is not None:
            ev = Events.EPOCH_COMPLETED(once=save_epoch)
            save_iter = save_epoch * (data_size // batch_size)
        else:
            raise ValueError("Either save_iter or save_epoch should be provided")

        @trainer.on(ev)
        def save_chkpt(_):
            if debug:
                print(trainer.state.iteration, "save_chkpt")
            fp = dirname / "test.pt"
            from ignite.engine.deterministic import _repr_rng_state

            tsd = trainer.state_dict()
            if debug:
                print("->", _repr_rng_state(tsd["rng_states"]))
            torch.save([model.state_dict(), opt.state_dict(), tsd], fp)
            chkpt.append(fp)

        def log_event_filter(_, event):
            # fire on the first five iterations right after the checkpoint iteration
            return event // save_iter == 1 and 1 <= event % save_iter <= 5

        @trainer.on(Events.ITERATION_COMPLETED(event_filter=log_event_filter))
        def write_data_grads_weights(e):
            x = e.state.batch
            i = e.state.iteration
            data.append([i, x.mean().item(), x.std().item()])

            total = [0.0, 0.0]
            out1 = []
            out2 = []
            for p in model.parameters():
                n1 = torch.norm(p).item()
                n2 = torch.norm(p.grad).item()
                out1.append(n1)
                out2.append(n2)
                total[0] += n1
                total[1] += n2
            w_norms.append([i, total[0]] + out1)
            grad_norms.append([i, total[1]] + out2)

        if sd is not None:
            sd = torch.load(sd)
            model.load_state_dict(sd[0])
            opt.load_state_dict(sd[1])
            from ignite.engine.deterministic import _repr_rng_state

            if debug:
                print("-->", _repr_rng_state(sd[2]["rng_states"]))
            trainer.load_state_dict(sd[2])

        manual_seed(32)
        trainer.run(random_train_data_loader(size=data_size), max_epochs=5)
        return {
            "sd": chkpt,
            "data": data,
            "grads": grad_norms,
            "weights": w_norms
        }
Example #15
    def _test(epoch_length=None):

        max_epochs = 5
        total_batch_size = 4
        num_iters = 21
        torch.manual_seed(0)
        data = torch.randint(0, 1000, size=(num_iters * total_batch_size, ))

        if epoch_length is None:
            epoch_length = num_iters

        for resume_epoch in range(1, max_epochs, 2):

            for num_workers in [0, 4]:
                sampler, batch_size = _setup_sampler(sampler_type, num_iters,
                                                     total_batch_size)

                orig_dataloader = DataLoader(
                    data,
                    batch_size=batch_size,
                    num_workers=num_workers,
                    pin_memory="cuda" in device,
                    sampler=sampler,
                    drop_last=True,
                    shuffle=sampler is None,
                )

                seen_batchs = []

                def update_fn(_, batch):
                    batch_to_device = batch.to(device)  # exercise the device transfer; value unused
                    seen_batchs.append(batch)

                engine = DeterministicEngine(update_fn)

                if sampler_type == "distributed":

                    @engine.on(Events.EPOCH_STARTED)
                    def _(engine):
                        sampler.set_epoch(engine.state.epoch - 1)

                torch.manual_seed(87)
                engine.run(
                    orig_dataloader,
                    max_epochs=max_epochs,
                    epoch_length=epoch_length,
                )

                batch_checker = BatchChecker(
                    seen_batchs, init_counter=resume_epoch * epoch_length
                )

                sampler, batch_size = _setup_sampler(sampler_type, num_iters,
                                                     total_batch_size)
                resume_dataloader = DataLoader(
                    data,
                    batch_size=batch_size,
                    num_workers=num_workers,
                    pin_memory="cuda" in device,
                    sampler=sampler,
                    drop_last=True,
                    shuffle=sampler is None,
                )

                def update_fn(_, batch):
                    batch_to_device = batch.to(device)
                    assert batch_checker.check(batch), (
                        f"{num_workers} {resume_epoch} | {batch_checker.counter}: "
                        f"{batch_checker.true_batch} vs {batch}"
                    )

                engine = DeterministicEngine(update_fn)

                if sampler_type == "distributed":

                    @engine.on(Events.EPOCH_STARTED)
                    def _(engine):
                        sampler.set_epoch(engine.state.epoch - 1)

                resume_state_dict = dict(epoch=resume_epoch,
                                         max_epochs=max_epochs,
                                         epoch_length=epoch_length,
                                         rng_states=None)
                engine.load_state_dict(resume_state_dict)
                torch.manual_seed(87)
                engine.run(resume_dataloader)
                assert engine.state.epoch == max_epochs
                assert engine.state.iteration == epoch_length * max_epochs
Example #16
    def _test(epoch_length=None):
        max_epochs = 3
        total_batch_size = 4
        num_iters = 17
        torch.manual_seed(0)
        data = torch.randint(0, 1000, size=(num_iters * total_batch_size, ))

        if epoch_length is None:
            epoch_length = num_iters

        for resume_iteration in range(
                2, min(num_iters * max_epochs, epoch_length * max_epochs), 13):

            for num_workers in [0, 2]:

                sampler, batch_size = _setup_sampler(sampler_type, num_iters,
                                                     total_batch_size)
                orig_dataloader = DataLoader(
                    data,
                    batch_size=batch_size,
                    num_workers=num_workers,
                    pin_memory="cuda" in torch.device(device).type,
                    sampler=sampler,
                    drop_last=True,
                    shuffle=sampler is None,
                )
                seen_batchs = []

                def update_fn(_, batch):
                    batch_to_device = batch.to(device)
                    seen_batchs.append(batch)

                engine = DeterministicEngine(update_fn)

                if sampler_type == "distributed":

                    @engine.on(Events.EPOCH_STARTED)
                    def _(engine):
                        sampler.set_epoch(engine.state.epoch)

                torch.manual_seed(12)
                engine.run(orig_dataloader,
                           max_epochs=max_epochs,
                           epoch_length=epoch_length)

                batch_checker = BatchChecker(seen_batchs,
                                             init_counter=resume_iteration)

                sampler, batch_size = _setup_sampler(sampler_type, num_iters,
                                                     total_batch_size)
                resume_dataloader = DataLoader(
                    data,
                    batch_size=batch_size,
                    num_workers=num_workers,
                    pin_memory="cuda" in torch.device(device).type,
                    sampler=sampler,
                    drop_last=True,
                    shuffle=sampler is None,
                )

                def update_fn(_, batch):
                    batch_to_device = batch.to(device)
                    cfg_msg = f"{num_workers} {resume_iteration}"
                    assert batch_checker.check(
                        batch
                    ), f"{cfg_msg} | {batch_checker.counter}: {batch_checker.true_batch} vs {batch}"

                engine = DeterministicEngine(update_fn)

                if sampler_type == "distributed":

                    @engine.on(Events.EPOCH_STARTED)
                    def _(engine):
                        sampler.set_epoch(engine.state.epoch)

                resume_state_dict = dict(iteration=resume_iteration,
                                         max_epochs=max_epochs,
                                         epoch_length=epoch_length,
                                         rng_states=None)
                engine.load_state_dict(resume_state_dict)
                torch.manual_seed(12)
                engine.run(resume_dataloader)
                assert engine.state.epoch == max_epochs
                assert (
                    engine.state.iteration == epoch_length * max_epochs
                ), f"{num_workers}, {resume_iteration} | {engine.state.iteration} vs {epoch_length * max_epochs}"
Example #17
def test_dengine_setup_seed_div_by_zero():
    with pytest.raises(ValueError,
                       match=r"iter_counter should be positive value"):
        DeterministicEngine(lambda e, b: None)._setup_seed(iter_counter=0)
Example #18
def create_supervised_trainer(
    model: torch.nn.Module,
    optimizer: torch.optim.Optimizer,
    loss_fn: Union[Callable, torch.nn.Module],
    device: Optional[Union[str, torch.device]] = None,
    non_blocking: bool = False,
    prepare_batch: Callable = _prepare_batch,
    output_transform: Callable[[Any, Any, Any, torch.Tensor],
                               Any] = lambda x, y, y_pred, loss: loss.item(),
    deterministic: bool = False,
    amp_mode: Optional[str] = None,
    scaler: Union[bool, "torch.cuda.amp.GradScaler"] = False,
    gradient_accumulation_steps: int = 1,
) -> Engine:
    """Factory function for creating a trainer for supervised models.

    Args:
        model: the model to train.
        optimizer: the optimizer to use.
        loss_fn: the loss function to use.
        device: device type specification (default: None).
            Applies to batches after starting the engine. Model *will not* be moved.
            Device can be CPU, GPU or TPU.
        non_blocking: if True and this copy is between CPU and GPU, the copy may occur asynchronously
            with respect to the host. For other cases, this argument has no effect.
        prepare_batch: function that receives `batch`, `device`, `non_blocking` and outputs
            tuple of tensors `(batch_x, batch_y)`.
        output_transform: function that receives 'x', 'y', 'y_pred', 'loss' and returns value
            to be assigned to engine's state.output after each iteration. Default is returning `loss.item()`.
        deterministic: if True, returns deterministic engine of type
            :class:`~ignite.engine.deterministic.DeterministicEngine`, otherwise :class:`~ignite.engine.engine.Engine`
            (default: False).
        amp_mode: can be ``amp`` or ``apex``, model and optimizer will be cast to float16 using
            `torch.cuda.amp <https://pytorch.org/docs/stable/amp.html>`_ for ``amp`` and
            using `apex <https://nvidia.github.io/apex>`_ for ``apex``. (default: None)
        scaler: GradScaler instance for gradient scaling if `torch>=1.6.0`
            and ``amp_mode`` is ``amp``. If ``amp_mode`` is ``apex``, this argument will be ignored.
            If True, will create default GradScaler. If GradScaler instance is passed, it will be used instead.
            (default: False)
        gradient_accumulation_steps: Number of steps the gradients should be accumulated across.
            (default: 1, meaning no gradient accumulation)

    Returns:
        a trainer engine with supervised update function.

    Examples:

        Create a trainer

        .. code-block:: python

            from ignite.engine import create_supervised_trainer
            from ignite.utils import convert_tensor
            from ignite.contrib.handlers.tqdm_logger import ProgressBar

            model = ...
            loss = ...
            optimizer = ...
            dataloader = ...

            def prepare_batch_fn(batch, device, non_blocking):
                x = ...  # get x from batch
                y = ...  # get y from batch

                # return a tuple of (x, y) that can be directly used as
                # `loss_fn(model(x), y)`
                return (
                    convert_tensor(x, device, non_blocking),
                    convert_tensor(y, device, non_blocking)
                )

            def output_transform_fn(x, y, y_pred, loss):
                # returning only the loss is actually the default behavior
                # for the trainer engine, but you can return anything you want
                return loss.item()

            trainer = create_supervised_trainer(
                model,
                optimizer,
                loss,
                prepare_batch=prepare_batch_fn,
                output_transform=output_transform_fn
            )

            pbar = ProgressBar()
            pbar.attach(trainer, output_transform=lambda x: {"loss": x})

            trainer.run(dataloader, max_epochs=5)

    Note:
        If ``scaler`` is True, a GradScaler instance will be created internally and exposed as the
        ``scaler`` attribute of the trainer state, where it can be used for saving and loading.

    Note:
        `engine.state.output` for this engine is defined by the `output_transform` parameter and is
        the loss of the processed batch by default.

    .. warning::
        The internal use of `device` has changed.
        `device` will now *only* be used to move the input data to the correct device.
        The `model` should be moved by the user before creating an optimizer.
        For more information see:

        - `PyTorch Documentation <https://pytorch.org/docs/stable/optim.html#constructing-it>`_
        - `PyTorch's Explanation <https://github.com/pytorch/pytorch/issues/7844#issuecomment-503713840>`_

    .. warning::
        If ``amp_mode='apex'`` , the model(s) and optimizer(s) must be initialized beforehand
        since ``amp.initialize`` should be called after you have finished constructing your model(s)
        and optimizer(s), but before you send your model through any DistributedDataParallel wrapper.

        See more: https://nvidia.github.io/apex/amp.html#module-apex.amp

    .. versionchanged:: 0.4.5

        - Added ``amp_mode`` argument for automatic mixed precision.
        - Added ``scaler`` argument for gradient scaling.

    .. versionchanged:: 0.5.0
        Added Gradient Accumulation argument for all supervised training methods.
    """

    device_type = device.type if isinstance(device, torch.device) else device
    on_tpu = "xla" in device_type if device_type is not None else False
    mode, _scaler = _check_arg(on_tpu, amp_mode, scaler)

    if mode == "amp":
        _update = supervised_training_step_amp(
            model,
            optimizer,
            loss_fn,
            device,
            non_blocking,
            prepare_batch,
            output_transform,
            _scaler,
            gradient_accumulation_steps,
        )
    elif mode == "apex":
        _update = supervised_training_step_apex(
            model,
            optimizer,
            loss_fn,
            device,
            non_blocking,
            prepare_batch,
            output_transform,
            gradient_accumulation_steps,
        )
    elif mode == "tpu":
        _update = supervised_training_step_tpu(
            model,
            optimizer,
            loss_fn,
            device,
            non_blocking,
            prepare_batch,
            output_transform,
            gradient_accumulation_steps,
        )
    else:
        _update = supervised_training_step(
            model,
            optimizer,
            loss_fn,
            device,
            non_blocking,
            prepare_batch,
            output_transform,
            gradient_accumulation_steps,
        )

    trainer = Engine(_update) if not deterministic else DeterministicEngine(_update)
    if _scaler and scaler and isinstance(scaler, bool):
        trainer.state.scaler = _scaler  # type: ignore[attr-defined]

    return trainer
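As a follow-up to the ``scaler`` note in the docstring, a minimal sketch (model, optimizer and loss_fn are placeholders) of requesting a deterministic trainer with native AMP and an internally created GradScaler:

trainer = create_supervised_trainer(
    model,
    optimizer,
    loss_fn,
    device="cuda",
    deterministic=True,  # returns a DeterministicEngine instead of a plain Engine
    amp_mode="amp",      # wrap the update step in torch.cuda.amp autocast
    scaler=True,         # create a default GradScaler internally
)
# the internally created scaler is exposed as trainer.state.scaler,
# e.g. for inclusion in a checkpoint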
Example #19
def create_supervised_trainer(
    model: torch.nn.Module,
    optimizer: torch.optim.Optimizer,
    loss_fn: Union[Callable, torch.nn.Module],
    device: Optional[Union[str, torch.device]] = None,
    non_blocking: bool = False,
    prepare_batch: Callable = _prepare_batch,
    output_transform: Callable = lambda x, y, y_pred, loss: loss.item(),
    deterministic: bool = False,
) -> Engine:
    """
    Factory function for creating a trainer for supervised models.

    Args:
        model (`torch.nn.Module`): the model to train.
        optimizer (`torch.optim.Optimizer`): the optimizer to use.
        loss_fn (torch.nn loss function): the loss function to use.
        device (str, optional): device type specification (default: None).
            Applies to batches after starting the engine. Model *will not* be moved.
            Device can be CPU, GPU or TPU.
        non_blocking (bool, optional): if True and this copy is between CPU and GPU, the copy may occur asynchronously
            with respect to the host. For other cases, this argument has no effect.
        prepare_batch (callable, optional): function that receives `batch`, `device`, `non_blocking` and outputs
            tuple of tensors `(batch_x, batch_y)`.
        output_transform (callable, optional): function that receives 'x', 'y', 'y_pred', 'loss' and returns value
            to be assigned to engine's state.output after each iteration. Default is returning `loss.item()`.
        deterministic (bool, optional): if True, returns deterministic engine of type
            :class:`~ignite.engine.deterministic.DeterministicEngine`, otherwise :class:`~ignite.engine.Engine`
            (default: False).

    Note:
        `engine.state.output` for this engine is defined by the `output_transform` parameter and is the loss
        of the processed batch by default.

    .. warning::
        The internal use of `device` has changed.
        `device` will now *only* be used to move the input data to the correct device.
        The `model` should be moved by the user before creating an optimizer.
        For more information see:

        * `PyTorch Documentation <https://pytorch.org/docs/stable/optim.html#constructing-it>`_
        * `PyTorch's Explanation <https://github.com/pytorch/pytorch/issues/7844#issuecomment-503713840>`_

    Returns:
        Engine: a trainer engine with supervised update function.
    """

    device_type = device.type if isinstance(device, torch.device) else device
    on_tpu = "xla" in device_type if device_type is not None else False

    if on_tpu and not idist.has_xla_support:
        raise RuntimeError(
            "In order to run on TPU, please install PyTorch XLA")

    def _update(
            engine: Engine,
            batch: Sequence[torch.Tensor]) -> Union[Any, Tuple[torch.Tensor]]:
        model.train()
        optimizer.zero_grad()
        x, y = prepare_batch(batch, device=device, non_blocking=non_blocking)
        y_pred = model(x)
        loss = loss_fn(y_pred, y)
        loss.backward()

        if on_tpu:
            xm.optimizer_step(optimizer, barrier=True)
        else:
            optimizer.step()

        return output_transform(x, y, y_pred, loss)

    trainer = Engine(_update) if not deterministic else DeterministicEngine(_update)

    return trainer
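A minimal usage sketch (the model and hyperparameters are placeholders) illustrating the warning in the docstring: move the model to the target device before constructing the optimizer; the `device` argument of the factory only affects the input batches:

import torch
import torch.nn as nn

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = nn.Linear(10, 2).to(device)  # move the model first
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

trainer = create_supervised_trainer(
    model, optimizer, nn.CrossEntropyLoss(), device=device, deterministic=True
)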