def test_state_dict():
    """A freshly built DeterministicEngine reports a four-key state dict with default values."""
    engine = DeterministicEngine(lambda e, b: 1)
    state = engine.state_dict()

    assert isinstance(state, Mapping)
    assert len(state) == 4

    # Counters start at zero; run-length settings are unset until `run` is called
    assert state.get("iteration") == 0
    assert "max_epochs" in state
    assert state["max_epochs"] is None
    assert "epoch_length" in state
    assert state["epoch_length"] is None

    # RNG states are captured eagerly, so they must already be present
    assert "rng_states" in state
    assert state["rng_states"] is not None
def test_engine_no_data_asserts():
    """Calling `run` without data on a DeterministicEngine must raise a ValueError."""
    trainer = DeterministicEngine(lambda e, b: None)
    expected_msg = r"Deterministic engine does not support the option of data=None"

    with pytest.raises(ValueError, match=expected_msg):
        trainer.run(max_epochs=10, epoch_length=10)
def test_dataloader_no_dataset_kind():
    """DeterministicEngine must run over a DataLoader wrapped in `OldDataLoader`.

    Regression test for https://github.com/pytorch/ignite/issues/1022
    (presumably `OldDataLoader` mimics a loader without a dataset kind, per the test name).
    """
    engine = DeterministicEngine(lambda e, b: None)

    batch_size = 4
    samples = torch.randint(0, 1000, size=(100 * batch_size,))
    loader = OldDataLoader(DataLoader(samples, batch_size=batch_size))

    engine.run(loader)
def _test(epoch_length=None):
    """Resume from several iterations over an infinite random iterator and compare batches.

    For each resume iteration, a reference run (seed 24) records every batch. A second engine,
    resumed via ``load_state_dict(iteration=...)`` with the same seed, must then see the
    reference batches starting at that iteration.

    Args:
        epoch_length: iterations per epoch; defaults to the iterator's inner cycle length (17).
    """
    max_epochs = 3
    batch_size = 4
    num_iters = 17

    def infinite_data_iterator():
        while True:
            for _ in range(num_iters):
                yield torch.randint(0, 1000, size=(batch_size,), device=device)

    if epoch_length is None:
        epoch_length = num_iters

    total_iters = epoch_length * max_epochs
    stop = min(num_iters * max_epochs, total_iters)

    for resume_iteration in range(1, stop, 7):
        reference_batches = []

        def record_fn(_, batch):
            reference_batches.append(batch)

        reference_engine = DeterministicEngine(record_fn)
        torch.manual_seed(24)
        reference_engine.run(infinite_data_iterator(), max_epochs=max_epochs, epoch_length=epoch_length)

        batch_checker = BatchChecker(reference_batches, init_counter=resume_iteration)

        def check_fn(_, batch):
            matched = batch_checker.check(batch)
            assert matched, f"{resume_iteration} | {batch_checker.counter}: {batch_checker.true_batch} vs {batch}"

        engine = DeterministicEngine(check_fn)
        engine.load_state_dict(
            dict(iteration=resume_iteration, max_epochs=max_epochs, epoch_length=epoch_length, rng_states=None)
        )
        torch.manual_seed(24)
        engine.run(infinite_data_iterator())

        assert engine.state.epoch == max_epochs
        assert engine.state.iteration == total_iters, f"{resume_iteration} | {engine.state.iteration} vs {total_iters}"
def _test(epoch_length=None):
    """Resume from each epoch over an infinite random iterator and compare batches.

    A reference run (seed 121) records all batches. An engine resumed via
    ``load_state_dict(epoch=...)`` with the same seed must then reproduce the reference batches
    from ``resume_epoch * epoch_length`` onward.

    Args:
        epoch_length: iterations per epoch; defaults to the iterator's inner cycle length (21).
    """
    max_epochs = 5
    batch_size = 4
    num_iters = 21

    def infinite_data_iterator():
        while True:
            for _ in range(num_iters):
                yield torch.randint(0, 1000, size=(batch_size,), device=device)

    if epoch_length is None:
        epoch_length = num_iters

    for resume_epoch in range(1, max_epochs):
        reference_batches = []

        def record_fn(_, batch):
            # A random op here (e.g. torch.rand(1)) would consume RNG and prevent correct resuming
            reference_batches.append(batch)

        reference_engine = DeterministicEngine(record_fn)
        torch.manual_seed(121)
        reference_engine.run(infinite_data_iterator(), max_epochs=max_epochs, epoch_length=epoch_length)

        batch_checker = BatchChecker(reference_batches, init_counter=resume_epoch * epoch_length)

        def check_fn(_, batch):
            matched = batch_checker.check(batch)
            assert matched, f"{resume_epoch} | {batch_checker.counter}: {batch_checker.true_batch} vs {batch}"

        engine = DeterministicEngine(check_fn)
        engine.load_state_dict(
            dict(epoch=resume_epoch, max_epochs=max_epochs, epoch_length=epoch_length, rng_states=None)
        )
        torch.manual_seed(121)
        engine.run(infinite_data_iterator())

        assert engine.state.epoch == max_epochs
        assert engine.state.iteration == epoch_length * max_epochs
def test_engine_with_dataloader_no_auto_batching():
    """DeterministicEngine runs over a DataLoader with auto-batching disabled (``batch_size=None``).

    Regression test for https://github.com/pytorch/ignite/issues/941. The sampler is a
    ``BatchSampler`` (batch_size=8), so it yields lists of indices rather than single indices.
    """
    data = torch.rand(64, 4, 10)
    batch_sampler = BatchSampler(RandomSampler(data), batch_size=8, drop_last=True)
    data_loader = DataLoader(data, batch_size=None, sampler=batch_sampler)

    calls = 0

    def process_fn(e, b):
        nonlocal calls
        print(f"{e.state.epoch}-{e.state.iteration}: {b}")
        calls += 1

    engine = DeterministicEngine(process_fn)
    engine.run(data_loader, epoch_length=10, max_epochs=5)

    # 5 epochs x 10 iterations
    assert calls == 50
def test_concepts_snippet_warning():
    """Run a snippet where a handler re-seeds the RNG and then draws from it.

    Nothing is asserted: the test only checks that the snippet runs to completion.
    """

    def random_train_data_generator():
        while True:
            yield torch.randint(0, 100, size=(1,))

    def print_train_data(engine, batch):
        state = engine.state
        print("train", state.epoch, state.iteration, batch.tolist())

    trainer = DeterministicEngine(print_train_data)

    @trainer.on(Events.ITERATION_COMPLETED(every=3))
    def user_handler(_):
        # handler synchronizes the random state ...
        torch.manual_seed(12)
        # ... and then consumes one random draw (the result is intentionally unused)
        torch.rand(1)

    trainer.run(random_train_data_generator(), max_epochs=3, epoch_length=5)
def test_run_finite_iterator_no_epoch_length():
    """Run over a finite iterator of unknown size that is restarted after each exhaustion.

    Feature request: https://github.com/pytorch/ignite/issues/871. Each restart provides a fresh
    generator yielding ``0 .. unknown_size - 1``. Every batch is checked against that repeating
    sequence, and the run must end after 5 epochs of ``unknown_size`` iterations.
    """
    unknown_size = 11

    def finite_unk_size_data_iter():
        for i in range(unknown_size):
            yield i

    bc = BatchChecker(data=list(range(unknown_size)))

    def update_fn(_, batch):
        # BatchChecker.check only reports the comparison; assert it so a wrong batch fails the test
        assert bc.check(batch), f"{bc.counter}: {bc.true_batch} vs {batch}"

    engine = DeterministicEngine(update_fn)

    @engine.on(Events.DATALOADER_STOP_ITERATION)
    def restart_iter():
        # Replace the exhausted generator so the next epoch can keep iterating
        engine.state.dataloader = finite_unk_size_data_iter()

    data_iter = finite_unk_size_data_iter()
    engine.run(data_iter, max_epochs=5)

    assert engine.state.epoch == 5
    assert engine.state.iteration == unknown_size * 5
def _test(epoch_length=None):
    """Resume from each epoch over a fixed random tensor and check the batches line up.

    Args:
        epoch_length: iterations per epoch; defaults to the data length (21).
    """
    max_epochs = 10
    num_iters = 21

    torch.manual_seed(0)
    data = torch.randint(0, 1000, size=(num_iters,))

    if epoch_length is None:
        epoch_length = num_iters

    for resume_epoch in range(1, max_epochs):
        batch_checker = BatchChecker(data, init_counter=resume_epoch * epoch_length)

        def update_fn(_, batch):
            matched = batch_checker.check(batch)
            assert matched, f"{resume_epoch} | {batch_checker.counter}: {batch_checker.true_batch} vs {batch}"

        engine = DeterministicEngine(update_fn)
        engine.load_state_dict(
            dict(epoch=resume_epoch, max_epochs=max_epochs, epoch_length=epoch_length, rng_states=None)
        )
        engine.run(data)

        assert engine.state.epoch == max_epochs
        assert engine.state.iteration == epoch_length * max_epochs
def _test(epoch_length=None):
    """Resume from several iterations over a fixed random tensor and check the batches.

    At each completed epoch, the engine's iteration counter must equal the checker's counter.

    Args:
        epoch_length: iterations per epoch; defaults to the data length (21).
    """
    max_epochs = 5
    num_iters = 21

    torch.manual_seed(0)
    data = torch.randint(0, 1000, size=(num_iters,))

    if epoch_length is None:
        epoch_length = num_iters

    stop = min(num_iters * max_epochs, epoch_length * max_epochs)
    for resume_iteration in range(2, stop, 4):
        batch_checker = BatchChecker(data, init_counter=resume_iteration)

        def update_fn(_, batch):
            matched = batch_checker.check(batch)
            assert matched, f"{resume_iteration} | {batch_checker.counter}: {batch_checker.true_batch} vs {batch}"

        engine = DeterministicEngine(update_fn)

        @engine.on(Events.EPOCH_COMPLETED)
        def check_iteration(_):
            assert engine.state.iteration == batch_checker.counter

        engine.load_state_dict(
            dict(iteration=resume_iteration, max_epochs=max_epochs, epoch_length=epoch_length, rng_states=None)
        )
        engine.run(data)

        assert engine.state.epoch == max_epochs
        assert engine.state.iteration == epoch_length * max_epochs
def test_engine_with_dataloader_no_auto_batching():
    """DeterministicEngine runs over a DataLoader with auto-batching disabled (``batch_size=None``).

    Regression test for https://github.com/pytorch/ignite/issues/941. The sampler is a
    ``BatchSampler`` (batch_size=8), so it yields lists of indices rather than single indices.
    """
    from torch.utils.data import BatchSampler, DataLoader, RandomSampler

    data = torch.rand(64, 4, 10)
    # Use the locally imported DataLoader rather than the fully qualified torch.utils.data path
    data_loader = DataLoader(
        data,
        batch_size=None,
        sampler=BatchSampler(RandomSampler(data), batch_size=8, drop_last=True),
    )

    counter = [0]

    def foo(e, b):
        print(f"{e.state.epoch}-{e.state.iteration}: {b}")
        counter[0] += 1

    engine = DeterministicEngine(foo)
    engine.run(data_loader, epoch_length=10, max_epochs=5)

    # 5 epochs x 10 iterations
    assert counter[0] == 50
def test_concepts_snippet_resume():
    """Resume a DeterministicEngine from epoch 1 and check it replays the original second epoch.

    The original run covers 2 epochs of 5 iterations. The resumed run (same seed) must produce
    exactly the batches the original run saw in its second epoch.
    """
    # Commented imports required in the snippet
    # import torch
    # from torch.utils.data import DataLoader
    # from ignite.engine import DeterministicEngine
    # from ignite.utils import manual_seed

    epoch_length = 5
    max_epochs = 2

    seen_batches = []
    manual_seed(seed=15)

    def random_train_data_loader(size):
        data = torch.arange(0, size)
        return DataLoader(data, batch_size=4, shuffle=True)

    def print_train_data(engine, batch):
        i = engine.state.iteration
        e = engine.state.epoch
        print("train", e, i, batch.tolist())
        seen_batches.append(batch)

    trainer = DeterministicEngine(print_train_data)

    print("Original Run")
    manual_seed(56)
    trainer.run(random_train_data_loader(40), max_epochs=max_epochs, epoch_length=epoch_length)

    original_batches = list(seen_batches)
    # Reset in place so the handler keeps appending to the same list object
    seen_batches.clear()

    print("Resumed Run")
    trainer.load_state_dict({"epoch": 1, "epoch_length": epoch_length, "max_epochs": max_epochs, "rng_states": None})
    manual_seed(56)
    trainer.run(random_train_data_loader(40))

    resumed_batches = list(seen_batches)

    # zip() stops at the shorter sequence; check lengths first so an empty or truncated resumed
    # run cannot pass vacuously.
    expected_batches = original_batches[epoch_length:]
    assert len(resumed_batches) == len(expected_batches), f"{len(resumed_batches)} vs {len(expected_batches)}"
    for b1, b2 in zip(expected_batches, resumed_batches):
        assert (b1 == b2).all()
def run(output_path, config):
    """Train a classifier with Ignite (optionally deterministically) and log to TensorBoard.

    Args:
        output_path: directory used for checkpoints and TensorBoard logs.
        config: mapping of run options. Keys read here include ``seed``, ``num_warmup_epochs``,
            ``learning_rate``, ``num_epochs``, ``deterministic``, ``checkpoint_every``,
            ``display_iters``, ``validate_every``, ``log_model_grads_every``, ``crash_iteration``
            and ``resume_from``. ``utils.get_dataflow`` / ``utils.get_model_optimizer`` may read more.

    Any exception raised by ``trainer.run``, including the emulated crash, is printed and swallowed.
    On rank 0 the TensorBoard logger is then closed.
    """
    distributed = dist.is_available() and dist.is_initialized()
    rank = dist.get_rank() if distributed else 0

    # Seed is offset by the rank, so each process gets a different seed
    manual_seed(config["seed"] + rank)

    # Setup dataflow, model, optimizer, criterion
    train_loader, test_loader = utils.get_dataflow(config, distributed)
    model, optimizer = utils.get_model_optimizer(config, distributed)
    criterion = nn.CrossEntropyLoss().to(utils.device)

    # LR ramps 0 -> learning_rate over the warm-up epochs, then back to 0 at the end of training.
    # Milestones are len(train_loader) * epochs, i.e. presumably counted in iterations.
    le = len(train_loader)
    milestones_values = [
        (0, 0.0),
        (le * config["num_warmup_epochs"], config["learning_rate"]),
        (le * config["num_epochs"], 0.0),
    ]
    lr_scheduler = PiecewiseLinear(optimizer, param_name="lr", milestones_values=milestones_values)

    # Setup Ignite trainer:
    # - let's define training step
    # - add other common handlers:
    #    - TerminateOnNan,
    #    - handler to setup learning rate scheduling,
    #    - ModelCheckpoint
    #    - RunningAverage on `train_step` output
    #    - Two progress bars on epochs and optionally on iterations

    def train_step(engine, batch):
        x = convert_tensor(batch[0], device=utils.device, non_blocking=True)
        y = convert_tensor(batch[1], device=utils.device, non_blocking=True)

        model.train()
        # Supervised part
        y_pred = model(x)
        loss = criterion(y_pred, y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        return {
            "batch loss": loss.item(),
        }

    if config["deterministic"] and rank == 0:
        print("Setup deterministic trainer")
    trainer = Engine(train_step) if not config["deterministic"] else DeterministicEngine(train_step)
    train_sampler = train_loader.sampler if distributed else None

    to_save = {"trainer": trainer, "model": model, "optimizer": optimizer, "lr_scheduler": lr_scheduler}
    metric_names = [
        "batch loss",
    ]
    common.setup_common_training_handlers(
        trainer,
        train_sampler=train_sampler,
        to_save=to_save,
        save_every_iters=config["checkpoint_every"],
        output_path=output_path,
        lr_scheduler=lr_scheduler,
        output_names=metric_names,
        with_pbar_on_iters=config["display_iters"],
        log_every_iters=10,
    )

    if rank == 0:
        # Setup Tensorboard logger - wrapper on SummaryWriter
        tb_logger = TensorboardLogger(log_dir=output_path)
        # Attach logger to the trainer and log trainer's metrics (stored in trainer.state.metrics) every iteration
        tb_logger.attach(
            trainer,
            log_handler=OutputHandler(tag="train", metric_names=metric_names),
            event_name=Events.ITERATION_COMPLETED,
        )
        # log optimizer's parameters: "lr" every iteration
        tb_logger.attach(
            trainer, log_handler=OptimizerParamsHandler(optimizer, param_name="lr"), event_name=Events.ITERATION_STARTED
        )

    # Let's now setup evaluator engine to perform model's validation and compute metrics
    metrics = {
        "accuracy": Accuracy(device=utils.device if distributed else None),
        "loss": Loss(criterion, device=utils.device if distributed else None),
    }

    # We define two evaluators as they won't have exactly the same roles:
    # - `evaluator` will save the best model based on validation score
    # - `train_evaluator` only computes metrics on the training data
    evaluator = create_supervised_evaluator(model, metrics=metrics, device=utils.device, non_blocking=True)
    train_evaluator = create_supervised_evaluator(model, metrics=metrics, device=utils.device, non_blocking=True)

    def run_validation(engine):
        train_evaluator.run(train_loader)
        evaluator.run(test_loader)

    # Validate once an epoch has been trained. EPOCH_STARTED would evaluate the untrained model at
    # epoch 1 and log each result under the following epoch's number.
    # NOTE: if num_epochs is a multiple of validate_every, the final epoch is validated twice
    # (EPOCH_COMPLETED and COMPLETED).
    trainer.add_event_handler(Events.EPOCH_COMPLETED(every=config["validate_every"]), run_validation)
    trainer.add_event_handler(Events.COMPLETED, run_validation)

    if rank == 0:
        # Setup progress bar on evaluation engines
        if config["display_iters"]:
            ProgressBar(persist=False, desc="Train evaluation").attach(train_evaluator)
            ProgressBar(persist=False, desc="Test evaluation").attach(evaluator)

        # Let's log metrics of `train_evaluator` stored in `train_evaluator.state.metrics` when validation run is done
        tb_logger.attach(
            train_evaluator,
            log_handler=OutputHandler(
                tag="train", metric_names="all", global_step_transform=global_step_from_engine(trainer)
            ),
            event_name=Events.COMPLETED,
        )

        # Let's log metrics of `evaluator` stored in `evaluator.state.metrics` when validation run is done
        tb_logger.attach(
            evaluator,
            log_handler=OutputHandler(
                tag="test", metric_names="all", global_step_transform=global_step_from_engine(trainer)
            ),
            event_name=Events.COMPLETED,
        )

        # Store 3 best models by validation accuracy:
        common.save_best_model_by_val_score(
            output_path, evaluator, model=model, metric_name="accuracy", n_saved=3, trainer=trainer, tag="test"
        )

        # Optionally log model gradients
        if config["log_model_grads_every"] is not None:
            tb_logger.attach(
                trainer,
                log_handler=GradsHistHandler(model, tag=model.__class__.__name__),
                event_name=Events.ITERATION_COMPLETED(every=config["log_model_grads_every"]),
            )

    # In order to check training resuming we can emulate a crash
    if config["crash_iteration"] is not None:

        @trainer.on(Events.ITERATION_STARTED(once=config["crash_iteration"]))
        def _(engine):
            raise Exception("STOP at iteration: {}".format(engine.state.iteration))

    resume_from = config["resume_from"]
    if resume_from is not None:
        checkpoint_fp = Path(resume_from)
        assert checkpoint_fp.exists(), "Checkpoint '{}' is not found".format(checkpoint_fp.as_posix())
        print("Resume from a checkpoint: {}".format(checkpoint_fp.as_posix()))
        # Load onto CPU so every rank does not materialize tensors on the device they were saved from.
        # The model/optimizer `load_state_dict` calls copy values onto their parameters' devices.
        checkpoint = torch.load(checkpoint_fp.as_posix(), map_location="cpu")
        Checkpoint.load_objects(to_load=to_save, checkpoint=checkpoint)

    try:
        trainer.run(train_loader, max_epochs=config["num_epochs"])
    except Exception:
        # Best-effort: report the failure (e.g. the emulated crash) but still close the logger below
        import traceback

        print(traceback.format_exc())

    if rank == 0:
        tb_logger.close()
def _train(save_iter=None, save_epoch=None, sd=None):
    """Train a small conv net with a DeterministicEngine and record per-iteration diagnostics.

    Exactly one checkpoint is written, after iteration ``save_iter`` or after epoch ``save_epoch``
    (``save_iter`` takes precedence if both are given). Batch statistics, weight norms and gradient
    norms are recorded for iterations ``i`` with ``save_iter < i < 2 * save_iter`` and
    ``i - save_iter <= 5``.

    Args:
        save_iter: iteration after which to save a checkpoint.
        save_epoch: epoch after which to save a checkpoint (used only if ``save_iter`` is None).
            It is converted to an iteration count as ``save_epoch * (data_size // batch_size)``.
        sd: optional path to a checkpoint written by a previous call. If given, the model,
            optimizer and trainer state are restored from it before running.

    Returns:
        dict with keys:
            ``"sd"``: list of saved checkpoint paths.
            ``"data"``: ``[iteration, batch_mean, batch_std]`` rows.
            ``"grads"`` / ``"weights"``: ``[iteration, total_norm, *per_param_norms]`` rows.

    Raises:
        ValueError: if neither ``save_iter`` nor ``save_epoch`` is given (previously an
            unbound-local error at handler registration).
    """
    if save_iter is None and save_epoch is None:
        raise ValueError("Either save_iter or save_epoch should be provided")

    w_norms = []
    grad_norms = []
    data = []
    chkpt = []

    # Seed before building layers so the initial weights are reproducible
    manual_seed(12)
    arch = [
        nn.Conv2d(3, 10, 3),
        nn.ReLU(),
        nn.Conv2d(10, 10, 3),
        nn.ReLU(),
        nn.AdaptiveAvgPool2d(1),
        nn.Flatten(),
        nn.Linear(10, 5),
        nn.ReLU(),
        nn.Linear(5, 2),
    ]
    if with_dropout:
        arch.insert(2, nn.Dropout2d())
        arch.insert(-2, nn.Dropout())

    model = nn.Sequential(*arch).to(device)
    opt = SGD(model.parameters(), lr=0.001)

    def proc_fn(e, b):
        from ignite.engine.deterministic import _get_rng_states, _repr_rng_state

        # Snapshot of the RNG state before the forward pass; only printed in debug mode
        s = _repr_rng_state(_get_rng_states())
        model.train()
        opt.zero_grad()
        y = model(b.to(device))
        y.sum().backward()
        opt.step()
        if debug:
            print(trainer.state.iteration, trainer.state.epoch, "proc_fn - b.shape", b.shape, torch.norm(y).item(), s)

    trainer = DeterministicEngine(proc_fn)

    if save_iter is not None:
        ev = Events.ITERATION_COMPLETED(once=save_iter)
    else:
        ev = Events.EPOCH_COMPLETED(once=save_epoch)
        # Express the save point in iterations so the logging filter below works in both modes
        save_iter = save_epoch * (data_size // batch_size)

    @trainer.on(ev)
    def save_chkpt(_):
        if debug:
            print(trainer.state.iteration, "save_chkpt")
        fp = dirname / "test.pt"
        from ignite.engine.deterministic import _repr_rng_state

        tsd = trainer.state_dict()
        if debug:
            print("->", _repr_rng_state(tsd["rng_states"]))
        torch.save([model.state_dict(), opt.state_dict(), tsd], fp)
        chkpt.append(fp)

    def log_event_filter(_, event):
        # Keep iterations in (save_iter, 2 * save_iter) that are at most 5 steps past save_iter
        return (event // save_iter == 1) and 1 <= (event % save_iter) <= 5

    @trainer.on(Events.ITERATION_COMPLETED(event_filter=log_event_filter))
    def write_data_grads_weights(e):
        x = e.state.batch
        i = e.state.iteration
        data.append([i, x.mean().item(), x.std().item()])

        total = [0.0, 0.0]
        out1 = []
        out2 = []
        for p in model.parameters():
            n1 = torch.norm(p).item()
            n2 = torch.norm(p.grad).item()
            out1.append(n1)
            out2.append(n2)
            total[0] += n1
            total[1] += n2
        w_norms.append([i, total[0]] + out1)
        grad_norms.append([i, total[1]] + out2)

    if sd is not None:
        # `sd` is a checkpoint path: [model state, optimizer state, trainer state]
        checkpoint = torch.load(sd)
        model.load_state_dict(checkpoint[0])
        opt.load_state_dict(checkpoint[1])
        from ignite.engine.deterministic import _repr_rng_state

        if debug:
            print("-->", _repr_rng_state(checkpoint[2]["rng_states"]))
        trainer.load_state_dict(checkpoint[2])

    manual_seed(32)
    trainer.run(random_train_data_loader(size=data_size), max_epochs=5)
    return {"sd": chkpt, "data": data, "grads": grad_norms, "weights": w_norms}
def _test(epoch_length=None):
    """Resume from several epochs with DataLoader inputs and check the replayed batches.

    For each ``resume_epoch`` and ``num_workers``, a reference run over a fresh DataLoader
    records all batches (seed 87). A second engine, resumed via ``load_state_dict(epoch=...)``
    over an identically built DataLoader, must then see the reference batches from
    ``resume_epoch * epoch_length`` onward.

    Uses the free names ``device``, ``sampler_type``, ``_setup_sampler`` and ``BatchChecker``,
    which are defined outside this function.

    Args:
        epoch_length: iterations per epoch; defaults to 21.
    """
    max_epochs = 5
    total_batch_size = 4
    num_iters = 21
    torch.manual_seed(0)
    data = torch.randint(0, 1000, size=(num_iters * total_batch_size,))

    if epoch_length is None:
        epoch_length = num_iters

    # torch.device(...).type also accepts torch.device objects, whereas `"cuda" in device` would raise
    pin_memory = "cuda" in torch.device(device).type

    def _make_dataloader(num_workers):
        # Build a fresh sampler + DataLoader; the sampler is returned so its epoch can be set
        sampler, batch_size = _setup_sampler(sampler_type, num_iters, total_batch_size)
        dataloader = DataLoader(
            data,
            batch_size=batch_size,
            num_workers=num_workers,
            pin_memory=pin_memory,
            sampler=sampler,
            drop_last=True,
            shuffle=sampler is None,
        )
        return dataloader, sampler

    def _maybe_set_sampler_epoch(engine, sampler):
        # For the distributed sampler, set its epoch at every epoch start (same rule in both runs)
        if sampler_type == "distributed":

            @engine.on(Events.EPOCH_STARTED)
            def _(engine):
                sampler.set_epoch(engine.state.epoch - 1)

    for resume_epoch in range(1, max_epochs, 2):
        for num_workers in [0, 4]:
            orig_dataloader, orig_sampler = _make_dataloader(num_workers)

            seen_batchs = []

            def update_fn(_, batch):
                batch.to(device)  # the moved copy is not used; only the transfer is performed
                seen_batchs.append(batch)

            engine = DeterministicEngine(update_fn)
            _maybe_set_sampler_epoch(engine, orig_sampler)

            torch.manual_seed(87)
            engine.run(orig_dataloader, max_epochs=max_epochs, epoch_length=epoch_length)

            batch_checker = BatchChecker(seen_batchs, init_counter=resume_epoch * epoch_length)

            resume_dataloader, resume_sampler = _make_dataloader(num_workers)

            def update_fn(_, batch):
                batch.to(device)  # the moved copy is not used; only the transfer is performed
                assert batch_checker.check(
                    batch
                ), f"{num_workers} {resume_epoch} | {batch_checker.counter}: {batch_checker.true_batch} vs {batch}"

            engine = DeterministicEngine(update_fn)
            _maybe_set_sampler_epoch(engine, resume_sampler)

            resume_state_dict = dict(
                epoch=resume_epoch, max_epochs=max_epochs, epoch_length=epoch_length, rng_states=None
            )
            engine.load_state_dict(resume_state_dict)
            torch.manual_seed(87)
            engine.run(resume_dataloader)
            assert engine.state.epoch == max_epochs
            assert engine.state.iteration == epoch_length * max_epochs
def _test(epoch_length=None):
    """Resume from several iterations with DataLoader inputs and check the replayed batches.

    For each ``resume_iteration`` and ``num_workers``, a reference run (seed 12) records all
    batches. A second engine, resumed via ``load_state_dict(iteration=...)`` over an
    identically built DataLoader, must then see the reference batches from that iteration on.

    Args:
        epoch_length: iterations per epoch; defaults to 17.
    """
    max_epochs = 3
    total_batch_size = 4
    num_iters = 17
    torch.manual_seed(0)
    data = torch.randint(0, 1000, size=(num_iters * total_batch_size,))

    if epoch_length is None:
        epoch_length = num_iters

    total_iters = epoch_length * max_epochs
    pin_memory = "cuda" in torch.device(device).type

    def _build_loader(num_workers):
        # A fresh sampler and DataLoader per run; the sampler is returned for epoch syncing
        sampler, batch_size = _setup_sampler(sampler_type, num_iters, total_batch_size)
        loader = DataLoader(
            data,
            batch_size=batch_size,
            num_workers=num_workers,
            pin_memory=pin_memory,
            sampler=sampler,
            drop_last=True,
            shuffle=sampler is None,
        )
        return loader, sampler

    def _sync_sampler_epoch(engine, sampler):
        if sampler_type != "distributed":
            return

        @engine.on(Events.EPOCH_STARTED)
        def _(engine):
            sampler.set_epoch(engine.state.epoch)

    for resume_iteration in range(2, min(num_iters * max_epochs, total_iters), 13):
        for num_workers in [0, 2]:
            reference_loader, reference_sampler = _build_loader(num_workers)

            reference_batches = []

            def record_fn(_, batch):
                _ = batch.to(device)
                reference_batches.append(batch)

            reference_engine = DeterministicEngine(record_fn)
            _sync_sampler_epoch(reference_engine, reference_sampler)

            torch.manual_seed(12)
            reference_engine.run(reference_loader, max_epochs=max_epochs, epoch_length=epoch_length)

            batch_checker = BatchChecker(reference_batches, init_counter=resume_iteration)

            resume_loader, resume_sampler = _build_loader(num_workers)

            def check_fn(_, batch):
                _ = batch.to(device)
                matched = batch_checker.check(batch)
                assert (
                    matched
                ), f"{num_workers} {resume_iteration} | {batch_checker.counter}: {batch_checker.true_batch} vs {batch}"

            engine = DeterministicEngine(check_fn)
            _sync_sampler_epoch(engine, resume_sampler)

            engine.load_state_dict(
                dict(iteration=resume_iteration, max_epochs=max_epochs, epoch_length=epoch_length, rng_states=None)
            )
            torch.manual_seed(12)
            engine.run(resume_loader)

            assert engine.state.epoch == max_epochs
            assert (
                engine.state.iteration == total_iters
            ), f"{num_workers}, {resume_iteration} | {engine.state.iteration} vs {total_iters}"
def test_dengine_setup_seed_div_by_zero():
    """`_setup_seed(iter_counter=0)` must be rejected with a ValueError."""
    expected_msg = r"iter_counter should be positive value"
    with pytest.raises(ValueError, match=expected_msg):
        engine = DeterministicEngine(lambda e, b: None)
        engine._setup_seed(iter_counter=0)
def create_supervised_trainer(
    model: torch.nn.Module,
    optimizer: torch.optim.Optimizer,
    loss_fn: Union[Callable, torch.nn.Module],
    device: Optional[Union[str, torch.device]] = None,
    non_blocking: bool = False,
    prepare_batch: Callable = _prepare_batch,
    output_transform: Callable[[Any, Any, Any, torch.Tensor], Any] = lambda x, y, y_pred, loss: loss.item(),
    deterministic: bool = False,
    amp_mode: Optional[str] = None,
    scaler: Union[bool, "torch.cuda.amp.GradScaler"] = False,
    gradient_accumulation_steps: int = 1,
) -> Engine:
    """Factory function for creating a trainer for supervised models.

    Args:
        model: the model to train.
        optimizer: the optimizer to use.
        loss_fn: the loss function to use.
        device: device type specification (default: None).
            Applies to batches after starting the engine. Model *will not* be moved.
            Device can be CPU, GPU or TPU.
        non_blocking: if True and this copy is between CPU and GPU, the copy may occur asynchronously
            with respect to the host. For other cases, this argument has no effect.
        prepare_batch: function that receives `batch`, `device`, `non_blocking` and outputs
            tuple of tensors `(batch_x, batch_y)`.
        output_transform: function that receives 'x', 'y', 'y_pred', 'loss' and returns value
            to be assigned to engine's state.output after each iteration. Default is returning `loss.item()`.
        deterministic: if True, returns deterministic engine of type
            :class:`~ignite.engine.deterministic.DeterministicEngine`, otherwise :class:`~ignite.engine.engine.Engine`
            (default: False).
        amp_mode: can be ``amp`` or ``apex``, model and optimizer will be cast to float16 using
            `torch.cuda.amp <https://pytorch.org/docs/stable/amp.html>`_ for ``amp`` and
            using `apex <https://nvidia.github.io/apex>`_ for ``apex``. (default: None)
        scaler: GradScaler instance for gradient scaling if `torch>=1.6.0`
            and ``amp_mode`` is ``amp``. If ``amp_mode`` is ``apex``, this argument will be ignored.
            If True, will create default GradScaler. If GradScaler instance is passed, it will be used instead.
            (default: False)
        gradient_accumulation_steps: Number of steps the gradients should be accumulated across.
            (default: 1 (means no gradient accumulation))

    Returns:
        a trainer engine with supervised update function.

    Examples:

        Create a trainer

        .. code-block:: python

            from ignite.engine import create_supervised_trainer
            from ignite.utils import convert_tensor
            from ignite.contrib.handlers.tqdm_logger import ProgressBar

            model = ...
            loss = ...
            optimizer = ...
            dataloader = ...

            def prepare_batch_fn(batch, device, non_blocking):
                x = ...  # get x from batch
                y = ...  # get y from batch

                # return a tuple of (x, y) that can be directly run as
                # `loss_fn(model(x), y)`
                return (
                    convert_tensor(x, device, non_blocking),
                    convert_tensor(y, device, non_blocking)
                )

            def output_transform_fn(x, y, y_pred, loss):
                # return only the loss is actually the default behavior for
                # trainer engine, but you can return anything you want
                return loss.item()

            trainer = create_supervised_trainer(
                model,
                optimizer,
                loss,
                prepare_batch=prepare_batch_fn,
                output_transform=output_transform_fn
            )

            pbar = ProgressBar()
            pbar.attach(trainer, output_transform=lambda x: {"loss": x})

            trainer.run(dataloader, max_epochs=5)

    Note:
        If ``scaler`` is True, GradScaler instance will be created internally and trainer state has attribute named
        ``scaler`` for that instance and can be used for saving and loading.

    Note:
        `engine.state.output` for this engine is defined by `output_transform` parameter and is the loss
        of the processed batch by default.

    .. warning::
        The internal use of `device` has changed.
        `device` will now *only* be used to move the input data to the correct device.
        The `model` should be moved by the user before creating an optimizer.
        For more information see:

        - `PyTorch Documentation <https://pytorch.org/docs/stable/optim.html#constructing-it>`_
        - `PyTorch's Explanation <https://github.com/pytorch/pytorch/issues/7844#issuecomment-503713840>`_

    .. warning::
        If ``amp_mode='apex'`` , the model(s) and optimizer(s) must be initialized beforehand
        since ``amp.initialize`` should be called after you have finished constructing your model(s)
        and optimizer(s), but before you send your model through any DistributedDataParallel wrapper.

        See more: https://nvidia.github.io/apex/amp.html#module-apex.amp

    .. versionchanged:: 0.4.5

        - Added ``amp_mode`` argument for automatic mixed precision.
        - Added ``scaler`` argument for gradient scaling.

    .. versionchanged:: 0.5.0
        Added Gradient Accumulation argument for all supervised training methods.
    """
    device_type = device.type if isinstance(device, torch.device) else device
    on_tpu = device_type is not None and "xla" in device_type
    mode, _scaler = _check_arg(on_tpu, amp_mode, scaler)

    # Positional arguments shared by every supervised_training_step* factory
    common_args = (model, optimizer, loss_fn, device, non_blocking, prepare_batch, output_transform)

    if mode == "amp":
        # Only the amp step receives the scaler, placed right before the accumulation steps
        _update = supervised_training_step_amp(*common_args, _scaler, gradient_accumulation_steps)
    elif mode == "apex":
        _update = supervised_training_step_apex(*common_args, gradient_accumulation_steps)
    elif mode == "tpu":
        _update = supervised_training_step_tpu(*common_args, gradient_accumulation_steps)
    else:
        _update = supervised_training_step(*common_args, gradient_accumulation_steps)

    trainer = DeterministicEngine(_update) if deterministic else Engine(_update)

    # Only when the caller asked for `scaler=True` is the internally created scaler exposed on the state
    if scaler is True and _scaler:
        trainer.state.scaler = _scaler  # type: ignore[attr-defined]

    return trainer
def create_supervised_trainer(
    model: torch.nn.Module,
    optimizer: torch.optim.Optimizer,
    loss_fn: Union[Callable, torch.nn.Module],
    device: Optional[Union[str, torch.device]] = None,
    non_blocking: bool = False,
    prepare_batch: Callable = _prepare_batch,
    output_transform: Callable = lambda x, y, y_pred, loss: loss.item(),
    deterministic: bool = False,
) -> Engine:
    """
    Factory function for creating a trainer for supervised models.

    Args:
        model (`torch.nn.Module`): the model to train.
        optimizer (`torch.optim.Optimizer`): the optimizer to use.
        loss_fn (torch.nn loss function): the loss function to use.
        device (str, optional): device type specification (default: None).
            Applies to batches after starting the engine. Model *will not* be moved.
            Device can be CPU, GPU or TPU.
        non_blocking (bool, optional): if True and this copy is between CPU and GPU, the copy may occur asynchronously
            with respect to the host. For other cases, this argument has no effect.
        prepare_batch (callable, optional): function that receives `batch`, `device`, `non_blocking` and outputs
            tuple of tensors `(batch_x, batch_y)`.
        output_transform (callable, optional): function that receives 'x', 'y', 'y_pred', 'loss' and returns value
            to be assigned to engine's state.output after each iteration. Default is returning `loss.item()`.
        deterministic (bool, optional): if True, returns deterministic engine of type
            :class:`~ignite.engine.deterministic.DeterministicEngine`, otherwise :class:`~ignite.engine.Engine`
            (default: False).

    Note:
        `engine.state.output` for this engine is defined by `output_transform` parameter and is the loss
        of the processed batch by default.

    .. warning::

        The internal use of `device` has changed.
        `device` will now *only* be used to move the input data to the correct device.
        The `model` should be moved by the user before creating an optimizer.
        For more information see:

        * `PyTorch Documentation <https://pytorch.org/docs/stable/optim.html#constructing-it>`_
        * `PyTorch's Explanation <https://github.com/pytorch/pytorch/issues/7844#issuecomment-503713840>`_

    Returns:
        Engine: a trainer engine with supervised update function.

    Raises:
        RuntimeError: if an XLA device is requested but PyTorch XLA support is unavailable.
    """
    device_type = device.type if isinstance(device, torch.device) else device
    on_tpu = device_type is not None and "xla" in device_type

    if on_tpu and not idist.has_xla_support:
        raise RuntimeError("In order to run on TPU, please install PyTorch XLA")

    def _update(engine: Engine, batch: Sequence[torch.Tensor]) -> Union[Any, Tuple[torch.Tensor]]:
        model.train()
        optimizer.zero_grad()
        x, y = prepare_batch(batch, device=device, non_blocking=non_blocking)
        y_pred = model(x)
        loss = loss_fn(y_pred, y)
        loss.backward()

        if on_tpu:
            # `xm` is presumably torch_xla.core.xla_model, imported elsewhere in the module
            xm.optimizer_step(optimizer, barrier=True)
        else:
            optimizer.step()

        return output_transform(x, y, y_pred, loss)

    engine_cls = DeterministicEngine if deterministic else Engine
    return engine_cls(_update)