def _test_distrib_config(local_rank, backend, ws, true_device, rank=None):
    assert idist.backend() == backend, f"{idist.backend()} vs {backend}"

    this_device = idist.device()
    assert isinstance(this_device, torch.device)
    if backend in ("nccl", "horovod") and "cuda" in this_device.type:
        true_device = torch.device(f"{true_device}:{local_rank}")
        assert this_device == true_device, f"{this_device} vs {true_device}"
    elif backend in ("gloo", "horovod"):
        assert this_device == torch.device(true_device)
    elif backend == "xla-tpu":
        assert true_device in this_device.type

    if rank is None:
        if idist.model_name() == "native-dist":
            rank = dist.get_rank()
    if rank is not None:
        assert idist.get_rank() == rank

    assert idist.get_world_size() == ws
    assert idist.get_local_rank() == local_rank

    assert idist.model_name() in ("native-dist", "xla-dist", "horovod-dist")

    _sanity_check()
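# Hedged usage sketch (not part of the test above): one way such a configuration check
# could be launched with the native "gloo" backend. `idist.spawn` passes the local rank as
# the first argument to the spawned function; world size 4 and device "cpu" are assumptions
# made only for this illustration.
import ignite.distributed as idist

def check_gloo_config():
    nproc = 4
    # spawns nproc processes, each calling _test_distrib_config(local_rank, "gloo", nproc, "cpu")
    idist.spawn("gloo", _test_distrib_config, args=("gloo", nproc, "cpu"), nproc_per_node=nproc)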
def log_basic_info(self, logger):
    logger.info("- PyTorch version: {}".format(torch.__version__))
    logger.info("- Ignite version: {}".format(ignite.__version__))
    if idist.get_world_size() > 1:
        logger.info("\nDistributed setting:")
        logger.info("\tbackend: {}".format(idist.backend()))
        logger.info("\tworld size: {}".format(idist.get_world_size()))
        logger.info("\n")
def _test_func(index, ws, device, backend, true_init_method):
    assert 0 <= index < ws
    assert index == idist.get_local_rank()
    assert ws == idist.get_world_size()
    assert torch.device(device).type == idist.device().type
    assert backend == idist.backend()

    if idist.model_name() == "native-dist":
        from ignite.distributed.utils import _model

        assert _model._init_method == true_init_method
def log_basic_info(logger, config):
    msg = "\n- PyTorch version: {}".format(torch.__version__)
    msg += "\n- Ignite version: {}".format(ignite.__version__)
    logger.info(msg)

    if idist.get_world_size() > 1:
        msg = "\nDistributed setting:"
        msg += "\tbackend: {}".format(idist.backend())
        msg += "\trank: {}".format(idist.get_rank())
        msg += "\tworld size: {}".format(idist.get_world_size())
        logger.info(msg)
def _mp_train(rank):
    # Specific ignite.distributed
    print(
        idist.get_rank(),
        "- backend=",
        idist.backend(),
        "- world size",
        idist.get_world_size(),
        "- device",
        idist.device(),
    )
    print(idist.get_rank(), " with seed ", torch.initial_seed())
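# Hedged launch sketch for _mp_train above: it only prints the distributed configuration of
# each process, so it can be driven directly by idist.Parallel. Backend "gloo" and 2 processes
# per node are assumptions chosen for illustration.
import ignite.distributed as idist

if __name__ == "__main__":
    with idist.Parallel(backend="gloo", nproc_per_node=2) as parallel:
        # each worker is called as _mp_train(local_rank)
        parallel.run(_mp_train)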
def log_basic_info(logger, config): logger.info("Train {} on CIFAR10".format(config["model"])) logger.info("- PyTorch version: {}".format(torch.__version__)) logger.info("- Ignite version: {}".format(ignite.__version__)) logger.info("\n") logger.info("Configuration:") for key, value in config.items(): logger.info("\t{}: {}".format(key, value)) logger.info("\n") if idist.get_world_size() > 1: logger.info("\nDistributed setting:") logger.info("\tbackend: {}".format(idist.backend())) logger.info("\tworld size: {}".format(idist.get_world_size())) logger.info("\n")
def log_basic_info(logger: Logger, config: ConfigSchema):
    logger.info("Experiment: {}".format(config.experiment_name))
    logger.info("- PyTorch version: {}".format(torch.__version__))
    logger.info("- Ignite version: {}".format(ignite.__version__))

    logger.info("\n")
    logger.info("Configuration:")
    for line in OmegaConf.to_yaml(config).split("\n"):
        logger.info("\t" + line)
    logger.info("\n")

    if idist.get_world_size() > 1:
        logger.info("\nDistributed setting:")
        logger.info("\tbackend: {}".format(idist.backend()))
        logger.info("\tworld size: {}".format(idist.get_world_size()))
        logger.info("\n")
def _test_auto_model(model, ws, device, sync_bn=False, **kwargs):
    model = auto_model(model, sync_bn=sync_bn, **kwargs)
    bnd = idist.backend()
    if ws > 1 and torch.device(device).type in ("cuda", "cpu"):
        if idist.has_native_dist_support and bnd in ("nccl", "gloo"):
            assert isinstance(model, nn.parallel.DistributedDataParallel)
            if sync_bn:
                assert any([isinstance(m, nn.SyncBatchNorm) for m in model.modules()])
            if "find_unused_parameters" in kwargs:
                assert model.find_unused_parameters == kwargs["find_unused_parameters"]
        elif idist.has_hvd_support and bnd in ("horovod",):
            assert isinstance(model, nn.Module)
    elif device != "cpu" and torch.cuda.is_available() and torch.cuda.device_count() > 1:
        assert isinstance(model, nn.parallel.DataParallel)
    else:
        assert isinstance(model, nn.Module)

    assert all(
        [p.device.type == torch.device(device).type for p in model.parameters()]
    ), f"{[p.device.type for p in model.parameters()]} vs {torch.device(device).type}"
def _test_auto_model(model, ws, device, sync_bn=False):
    model = auto_model(model, sync_bn=sync_bn)
    bnd = idist.backend()
    if ws > 1 and device in ("cuda", "cpu"):
        if idist.has_native_dist_support and bnd in ("nccl", "gloo"):
            assert isinstance(model, nn.parallel.DistributedDataParallel)
            if sync_bn:
                assert any([isinstance(m, nn.SyncBatchNorm) for m in model.modules()])
        elif idist.has_hvd_support and bnd in ("horovod",):
            assert isinstance(model, nn.Module)
    elif device != "cpu" and torch.cuda.is_available() and torch.cuda.device_count() > 1:
        assert isinstance(model, nn.parallel.DataParallel)
    else:
        assert isinstance(model, nn.Module)

    assert all([p.device.type == device for p in model.parameters()]), "{} vs {}".format(
        [p.device.type for p in model.parameters()], device
    )
def _test_auto_model_optimizer(ws, device):
    # Test auto_model
    model = nn.Linear(10, 10)
    _test_auto_model(model, ws, device)

    model = nn.Sequential(nn.Linear(20, 100), nn.BatchNorm1d(100))
    _test_auto_model(model, ws, device, sync_bn="cuda" in torch.device(device).type)
    if ws > 1:
        _test_auto_model(model, ws, device, find_unused_parameters=True)
        _test_auto_model(model, ws, device, find_unused_parameters=False)

    # Test auto_optim
    bnd = idist.backend()
    optimizer = optim.SGD(model.parameters(), lr=0.01)
    optimizer = auto_optim(optimizer)
    if idist.has_xla_support and "xla" in device:
        assert isinstance(optimizer, optim.SGD) and hasattr(optimizer, "wrapped_optimizer")
    elif idist.has_hvd_support and bnd in ("horovod",):
        assert isinstance(optimizer, optim.SGD) and hasattr(optimizer, "_allreduce_grad_async")
    else:
        assert isinstance(optimizer, optim.SGD) and not hasattr(optimizer, "wrapped_optimizer")
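# Hedged sketch of the helpers exercised by the tests above, outside of a test-suite:
# idist.auto_model, idist.auto_optim and idist.auto_dataloader adapt a plain model, optimizer
# and dataset to whatever backend idist was initialized with (DDP/DataParallel wrapping,
# Horovod optimizer, distributed sampler, ...). The model, dataset and hyper-parameters below
# are placeholders for illustration only.
import torch.nn as nn
import torch.optim as optim
import ignite.distributed as idist

def build_components(dataset):
    model = idist.auto_model(nn.Linear(10, 2))                 # wrapped per active backend
    optimizer = idist.auto_optim(optim.SGD(model.parameters(), lr=0.01))
    loader = idist.auto_dataloader(dataset, batch_size=32, shuffle=True)
    return model, optimizer, loader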
def test_no_distrib(capsys):
    from ignite.distributed.utils import _model

    print("test_no_distrib : dist: ", dist.is_available())
    print("test_no_distrib : _model", type(_model))

    assert idist.backend() is None
    if torch.cuda.is_available():
        assert idist.device().type == "cuda"
    else:
        assert idist.device().type == "cpu"
    assert idist.get_rank() == 0
    assert idist.get_world_size() == 1
    assert idist.get_local_rank() == 0
    assert idist.model_name() == "serial"

    from ignite.distributed.utils import _model, _SerialModel

    _sanity_check()
    assert isinstance(_model, _SerialModel)

    idist.show_config()
    captured = capsys.readouterr()
    out = captured.err.split("\r")
    out = list(map(lambda x: x.strip(), out))
    out = list(filter(None, out))
    assert "ignite.distributed.utils INFO: distributed configuration: serial" in out[-1]
    assert "ignite.distributed.utils INFO: backend: None" in out[-1]
    if torch.cuda.is_available():
        assert "ignite.distributed.utils INFO: device: cuda" in out[-1]
    else:
        assert "ignite.distributed.utils INFO: device: cpu" in out[-1]
    assert "ignite.distributed.utils INFO: rank: 0" in out[-1]
    assert "ignite.distributed.utils INFO: local rank: 0" in out[-1]
    assert "ignite.distributed.utils INFO: world size: 1" in out[-1]
def log_basic_info(logger: Logger, config: Any) -> None:
    """Log PyTorch and Ignite versions, the configuration, GPU system and distributed settings.

    Parameters
    ----------
    logger
        Logger instance for logging
    config
        config object to log
    """
    import ignite

    logger.info("PyTorch version: %s", torch.__version__)
    logger.info("Ignite version: %s", ignite.__version__)
    if torch.cuda.is_available():
        # explicitly import cudnn as
        # torch.backends.cudnn can not be pickled with hvd spawning procs
        from torch.backends import cudnn

        logger.info("GPU device: %s", torch.cuda.get_device_name(idist.get_local_rank()))
        logger.info("CUDA version: %s", torch.version.cuda)
        logger.info("CUDNN version: %s", cudnn.version())

    logger.info("Configuration: %s", pformat(vars(config)))

    if idist.get_world_size() > 1:
        logger.info("distributed configuration: %s", idist.model_name())
        logger.info("backend: %s", idist.backend())
        logger.info("device: %s", idist.device().type)
        logger.info("hostname: %s", idist.hostname())
        logger.info("world size: %s", idist.get_world_size())
        logger.info("rank: %s", idist.get_rank())
        logger.info("local rank: %s", idist.get_local_rank())
        logger.info("num processes per node: %s", idist.get_nproc_per_node())
        logger.info("num nodes: %s", idist.get_nnodes())
        logger.info("node rank: %s", idist.get_node_rank())
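# Hedged usage sketch for the logger above: `vars(config)` in log_basic_info implies a config
# object with a __dict__, e.g. an argparse.Namespace. The logger name and config fields below
# are illustrative placeholders, not part of the original code.
import logging
from argparse import Namespace

if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    demo_logger = logging.getLogger("demo")
    demo_config = Namespace(batch_size=64, lr=0.001, num_epochs=10)
    log_basic_info(demo_logger, demo_config)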
def create_trainer(
    train_step,
    output_names,
    model,
    ema_model,
    optimizer,
    lr_scheduler,
    supervised_train_loader,
    test_loader,
    cfg,
    logger,
    cta=None,
    unsup_train_loader=None,
    cta_probe_loader=None,
):
    trainer = Engine(train_step)
    trainer.logger = logger

    output_path = os.getcwd()

    to_save = {
        "model": model,
        "ema_model": ema_model,
        "optimizer": optimizer,
        "trainer": trainer,
        "lr_scheduler": lr_scheduler,
    }
    if cta is not None:
        to_save["cta"] = cta

    common.setup_common_training_handlers(
        trainer,
        train_sampler=supervised_train_loader.sampler,
        to_save=to_save,
        save_every_iters=cfg.solver.checkpoint_every,
        output_path=output_path,
        output_names=output_names,
        lr_scheduler=lr_scheduler,
        with_pbars=False,
        clear_cuda_cache=False,
    )

    ProgressBar(persist=False).attach(
        trainer, metric_names="all", event_name=Events.ITERATION_COMPLETED
    )

    unsupervised_train_loader_iter = None
    if unsup_train_loader is not None:
        unsupervised_train_loader_iter = cycle(unsup_train_loader)

    cta_probe_loader_iter = None
    if cta_probe_loader is not None:
        cta_probe_loader_iter = cycle(cta_probe_loader)

    # Setup handler to prepare data batches
    @trainer.on(Events.ITERATION_STARTED)
    def prepare_batch(e):
        sup_batch = e.state.batch
        e.state.batch = {
            "sup_batch": sup_batch,
        }
        if unsupervised_train_loader_iter is not None:
            unsup_batch = next(unsupervised_train_loader_iter)
            e.state.batch["unsup_batch"] = unsup_batch

        if cta_probe_loader_iter is not None:
            cta_probe_batch = next(cta_probe_loader_iter)
            cta_probe_batch["policy"] = [
                deserialize(p) for p in cta_probe_batch["policy"]
            ]
            e.state.batch["cta_probe_batch"] = cta_probe_batch

    # Setup handler to update EMA model
    @trainer.on(Events.ITERATION_COMPLETED, cfg.ema_decay)
    def update_ema_model(ema_decay):
        # EMA on parameters
        for ema_param, param in zip(ema_model.parameters(), model.parameters()):
            ema_param.data.mul_(ema_decay).add_(param.data, alpha=1.0 - ema_decay)

    # Setup handlers for debugging
    if cfg.debug:

        @trainer.on(Events.STARTED | Events.ITERATION_COMPLETED(every=100))
        @idist.one_rank_only()
        def log_weights_norms():
            wn = []
            ema_wn = []
            for ema_param, param in zip(ema_model.parameters(), model.parameters()):
                wn.append(torch.mean(param.data))
                ema_wn.append(torch.mean(ema_param.data))

            msg = "\n\nWeights norms"
            msg += "\n- Raw model: {}".format(
                to_list_str(torch.tensor(wn[:10] + wn[-10:]))
            )
            msg += "\n- EMA model: {}\n".format(
                to_list_str(torch.tensor(ema_wn[:10] + ema_wn[-10:]))
            )
            logger.info(msg)

            rmn = []
            rvar = []
            ema_rmn = []
            ema_rvar = []
            for m1, m2 in zip(model.modules(), ema_model.modules()):
                if isinstance(m1, nn.BatchNorm2d) and isinstance(m2, nn.BatchNorm2d):
                    rmn.append(torch.mean(m1.running_mean))
                    rvar.append(torch.mean(m1.running_var))
                    ema_rmn.append(torch.mean(m2.running_mean))
                    ema_rvar.append(torch.mean(m2.running_var))

            msg = "\n\nBN buffers"
            msg += "\n- Raw mean: {}".format(to_list_str(torch.tensor(rmn[:10])))
            msg += "\n- Raw var: {}".format(to_list_str(torch.tensor(rvar[:10])))
            msg += "\n- EMA mean: {}".format(to_list_str(torch.tensor(ema_rmn[:10])))
            msg += "\n- EMA var: {}\n".format(to_list_str(torch.tensor(ema_rvar[:10])))
            logger.info(msg)

    # TODO: Need to inspect a bug
    # if idist.get_rank() == 0:
    #     from ignite.contrib.handlers import ProgressBar
    #
    #     profiler = BasicTimeProfiler()
    #     profiler.attach(trainer)
    #
    #     @trainer.on(Events.ITERATION_COMPLETED(every=200))
    #     def log_profiling(_):
    #         results = profiler.get_results()
    #         profiler.print_results(results)

    # Setup validation engine
    metrics = {
        "accuracy": Accuracy(),
    }

    if not (idist.has_xla_support and idist.backend() == idist.xla.XLA_TPU):
        metrics.update({
            "precision": Precision(average=False),
            "recall": Recall(average=False),
        })

    eval_kwargs = dict(
        metrics=metrics,
        prepare_batch=sup_prepare_batch,
        device=idist.device(),
        non_blocking=True,
    )

    evaluator = create_supervised_evaluator(model, **eval_kwargs)
    ema_evaluator = create_supervised_evaluator(ema_model, **eval_kwargs)

    def log_results(epoch, max_epochs, metrics, ema_metrics):
        msg1 = "\n".join(
            ["\t{:16s}: {}".format(k, to_list_str(v)) for k, v in metrics.items()]
        )
        msg2 = "\n".join(
            ["\t{:16s}: {}".format(k, to_list_str(v)) for k, v in ema_metrics.items()]
        )
        logger.info(
            "\nEpoch {}/{}\nRaw:\n{}\nEMA:\n{}\n".format(epoch, max_epochs, msg1, msg2)
        )
        if cta is not None:
            logger.info("\n" + stats(cta))

    @trainer.on(
        Events.EPOCH_COMPLETED(every=cfg.solver.validate_every)
        | Events.STARTED
        | Events.COMPLETED
    )
    def run_evaluation():
        evaluator.run(test_loader)
        ema_evaluator.run(test_loader)
        log_results(
            trainer.state.epoch,
            trainer.state.max_epochs,
            evaluator.state.metrics,
            ema_evaluator.state.metrics,
        )

    # setup TB logging
    if idist.get_rank() == 0:
        tb_logger = common.setup_tb_logging(
            output_path,
            trainer,
            optimizers=optimizer,
            evaluators={"validation": evaluator, "ema validation": ema_evaluator},
            log_every_iters=15,
        )
        if cfg.online_exp_tracking.wandb:
            from ignite.contrib.handlers import WandBLogger

            wb_dir = Path("/tmp/output-fixmatch-wandb")
            if not wb_dir.exists():
                wb_dir.mkdir()
            _ = WandBLogger(
                project="fixmatch-pytorch",
                name=cfg.name,
                config=cfg,
                sync_tensorboard=True,
                dir=wb_dir.as_posix(),
                reinit=True,
            )

    resume_from = cfg.solver.resume_from
    if resume_from is not None:
        resume_from = list(Path(resume_from).rglob("training_checkpoint*.pt*"))
        if len(resume_from) > 0:
            # get latest
            checkpoint_fp = max(resume_from, key=lambda p: p.stat().st_mtime)
            assert checkpoint_fp.exists(), "Checkpoint '{}' is not found".format(
                checkpoint_fp.as_posix()
            )
            logger.info("Resume from a checkpoint: {}".format(checkpoint_fp.as_posix()))
            checkpoint = torch.load(checkpoint_fp.as_posix())
            Checkpoint.load_objects(to_load=to_save, checkpoint=checkpoint)

    @trainer.on(Events.COMPLETED)
    def release_all_resources():
        nonlocal unsupervised_train_loader_iter, cta_probe_loader_iter

        if idist.get_rank() == 0:
            tb_logger.close()

        if unsupervised_train_loader_iter is not None:
            unsupervised_train_loader_iter = None

        if cta_probe_loader_iter is not None:
            cta_probe_loader_iter = None

    return trainer
def test_idist_methods_no_dist():
    assert idist.get_world_size() < 2
    assert idist.backend() is None, "{}".format(idist.backend())
def create_trainer(model, optimizer, criterion, train_sampler, config, logger, with_clearml):
    device = config.device
    prepare_batch = data.prepare_image_mask

    # Setup trainer
    accumulation_steps = config.get("accumulation_steps", 1)
    model_output_transform = config.get("model_output_transform", lambda x: x)

    with_amp = config.get("with_amp", True)
    scaler = GradScaler(enabled=with_amp)

    def forward_pass(batch):
        model.train()
        x, y = prepare_batch(batch, device=device, non_blocking=True)
        with autocast(enabled=with_amp):
            y_pred = model(x)
            y_pred = model_output_transform(y_pred)
            loss = criterion(y_pred, y) / accumulation_steps
        return loss

    def amp_backward_pass(engine, loss):
        scaler.scale(loss).backward()
        if engine.state.iteration % accumulation_steps == 0:
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()

    def hvd_amp_backward_pass(engine, loss):
        scaler.scale(loss).backward()
        optimizer.synchronize()
        with optimizer.skip_synchronize():
            scaler.step(optimizer)
            scaler.update()
        optimizer.zero_grad()

    if idist.backend() == "horovod" and with_amp:
        backward_pass = hvd_amp_backward_pass
    else:
        backward_pass = amp_backward_pass

    def training_step(engine, batch):
        loss = forward_pass(batch)
        output = {"supervised batch loss": loss.item()}
        backward_pass(engine, loss)
        return output

    trainer = Engine(training_step)
    trainer.logger = logger

    output_names = [
        "supervised batch loss",
    ]
    lr_scheduler = config.lr_scheduler

    to_save = {
        "model": model,
        "optimizer": optimizer,
        "lr_scheduler": lr_scheduler,
        "trainer": trainer,
        "amp": scaler,
    }

    save_every_iters = config.get("save_every_iters", 1000)

    common.setup_common_training_handlers(
        trainer,
        train_sampler,
        to_save=to_save,
        save_every_iters=save_every_iters,
        save_handler=utils.get_save_handler(config.output_path.as_posix(), with_clearml),
        lr_scheduler=lr_scheduler,
        output_names=output_names,
        with_pbars=not with_clearml,
        log_every_iters=1,
    )

    resume_from = config.get("resume_from", None)
    if resume_from is not None:
        checkpoint_fp = Path(resume_from)
        assert checkpoint_fp.exists(), f"Checkpoint '{checkpoint_fp.as_posix()}' is not found"
        logger.info(f"Resume from a checkpoint: {checkpoint_fp.as_posix()}")
        checkpoint = torch.load(checkpoint_fp.as_posix(), map_location="cpu")
        Checkpoint.load_objects(to_load=to_save, checkpoint=checkpoint)

    return trainer
def training(rank, config):
    # Specific ignite.distributed
    print(
        idist.get_rank(),
        ": run with config:",
        config,
        "- backend=",
        idist.backend(),
        "- world size",
        idist.get_world_size(),
    )

    device = idist.device()

    # Data preparation:
    dataset = RndDataset(nb_samples=config["nb_samples"])

    # Specific ignite.distributed
    train_loader = idist.auto_dataloader(dataset, batch_size=config["batch_size"])

    # Model, criterion, optimizer setup
    model = idist.auto_model(wide_resnet50_2(num_classes=100))
    criterion = NLLLoss()
    optimizer = idist.auto_optim(SGD(model.parameters(), lr=0.01))

    # Training loop log param
    log_interval = config["log_interval"]

    def _train_step(engine, batch):
        data = batch[0].to(device)
        target = batch[1].to(device)

        optimizer.zero_grad()
        output = model(data)
        # NLLLoss expects log-probabilities over classes
        log_probabilities = torch.nn.functional.log_softmax(output, dim=1)
        loss_val = criterion(log_probabilities, target)
        loss_val.backward()
        optimizer.step()

        return loss_val

    # Running the _train_step function on whole batch_data iterable only once
    trainer = Engine(_train_step)

    # Add a logger
    @trainer.on(Events.ITERATION_COMPLETED(every=log_interval))
    def log_training():
        print(
            "Process {}/{} Train Epoch: {} [{}/{}]\tLoss: {}".format(
                idist.get_rank(),
                idist.get_world_size(),
                trainer.state.epoch,
                trainer.state.iteration * len(trainer.state.batch[0]),
                len(dataset) / idist.get_world_size(),
                trainer.state.output,
            )
        )

    trainer.run(train_loader, max_epochs=1)
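# Hedged launch sketch for the training function above: idist.Parallel spawns the workers and
# calls training(local_rank, config) in each of them. The backend, process count and config
# values here are assumptions chosen only to illustrate the entry point.
import ignite.distributed as idist

if __name__ == "__main__":
    config = {"nb_samples": 128, "batch_size": 16, "log_interval": 4}
    with idist.Parallel(backend="gloo", nproc_per_node=2) as parallel:
        parallel.run(training, config)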
def training(local_rank, config):
    rank = idist.get_rank()
    manual_seed(config["seed"] + rank)
    device = idist.device()

    logger = setup_logger(name="CIFAR10-QAT-Training", distributed_rank=local_rank)

    log_basic_info(logger, config)

    output_path = config["output_path"]
    if rank == 0:
        now = datetime.now().strftime("%Y%m%d-%H%M%S")
        folder_name = "{}_backend-{}-{}_{}".format(
            config["model"], idist.backend(), idist.get_world_size(), now
        )
        output_path = Path(output_path) / folder_name
        if not output_path.exists():
            output_path.mkdir(parents=True)
        config["output_path"] = output_path.as_posix()
        logger.info("Output path: {}".format(config["output_path"]))

        if "cuda" in device.type:
            config["cuda device name"] = torch.cuda.get_device_name(local_rank)

    # Setup dataflow, model, optimizer, criterion
    train_loader, test_loader = get_dataflow(config)

    config["num_iters_per_epoch"] = len(train_loader)
    model, optimizer, criterion, lr_scheduler = initialize(config)

    # Create trainer for current task
    trainer = create_trainer(model, optimizer, criterion, lr_scheduler, train_loader.sampler, config, logger)

    # Let's now setup evaluator engine to perform model's validation and compute metrics
    metrics = {
        "Accuracy": Accuracy(),
        "Loss": Loss(criterion),
    }

    # We define two evaluators as they won't have exactly similar roles:
    # - `evaluator` will save the best model based on validation score
    evaluator = create_supervised_evaluator(model, metrics=metrics, device=device, non_blocking=True)
    train_evaluator = create_supervised_evaluator(model, metrics=metrics, device=device, non_blocking=True)

    def run_validation(engine):
        epoch = trainer.state.epoch
        state = train_evaluator.run(train_loader)
        log_metrics(logger, epoch, state.times["COMPLETED"], "Train", state.metrics)
        state = evaluator.run(test_loader)
        log_metrics(logger, epoch, state.times["COMPLETED"], "Test", state.metrics)

    trainer.add_event_handler(
        Events.EPOCH_COMPLETED(every=config["validate_every"]) | Events.COMPLETED, run_validation
    )

    if rank == 0:
        # Setup TensorBoard logging on trainer and evaluators. Logged values are:
        #  - Training metrics, e.g. running average loss values
        #  - Learning rate
        #  - Evaluation train/test metrics
        evaluators = {"training": train_evaluator, "test": evaluator}
        tb_logger = common.setup_tb_logging(output_path, trainer, optimizer, evaluators=evaluators)

    # Store the best model by validation accuracy:
    common.save_best_model_by_val_score(
        output_path=config["output_path"],
        evaluator=evaluator,
        model=model,
        metric_name="Accuracy",
        n_saved=1,
        trainer=trainer,
        tag="test",
    )

    trainer.run(train_loader, max_epochs=config["num_epochs"])

    if rank == 0:
        tb_logger.close()
def test_idist_methods_no_dist():
    assert idist.get_world_size() < 2
    assert idist.backend() is None, f"{idist.backend()}"
def training(local_rank, config):
    rank = idist.get_rank()
    manual_seed(config["seed"] + rank)
    device = idist.device()

    logger = setup_logger(name="ImageNet-Training", distributed_rank=local_rank)

    log_basic_info(logger, config)

    output_path = config["output_path"]
    if rank == 0:
        if config["stop_iteration"] is None:
            now = datetime.now().strftime("%Y%m%d-%H%M%S")
        else:
            now = "stop-on-{}".format(config["stop_iteration"])

        folder_name = "{}_backend-{}-{}_{}".format(
            config["model"], idist.backend(), idist.get_world_size(), now
        )
        output_path = Path(output_path) / folder_name
        if not output_path.exists():
            output_path.mkdir(parents=True)
        config["output_path"] = output_path.as_posix()
        logger.info("Output path: {}".format(config["output_path"]))

        if "cuda" in device.type:
            config["cuda device name"] = torch.cuda.get_device_name(local_rank)

    # Setup dataflow, model, optimizer, criterion
    train_loader, test_loader = get_imagenet_dataloader(config)

    config["num_iters_per_epoch"] = len(train_loader)
    model, optimizer, criterion, lr_scheduler = initialize(config)

    # Create trainer for current task
    trainer = create_supervised_trainer(
        model, optimizer, criterion, lr_scheduler, train_loader.sampler, config, logger
    )

    # Let's now setup evaluator engine to perform model's validation and compute metrics
    metrics = {
        "accuracy": Accuracy(),
        "loss": Loss(criterion),
    }

    # We define two evaluators as they won't have exactly similar roles:
    # - `evaluator` will save the best model based on validation score
    evaluator = create_supervised_evaluator(model, metrics=metrics, device=device, non_blocking=True)
    train_evaluator = create_supervised_evaluator(model, metrics=metrics, device=device, non_blocking=True)

    def run_validation(engine):
        epoch = trainer.state.epoch
        state = train_evaluator.run(train_loader)
        log_metrics(logger, epoch, state.times["COMPLETED"], "Train", state.metrics)
        state = evaluator.run(test_loader)
        log_metrics(logger, epoch, state.times["COMPLETED"], "Test", state.metrics)

    trainer.add_event_handler(
        Events.EPOCH_COMPLETED(every=config["validate_every"]) | Events.COMPLETED, run_validation
    )

    if rank == 0:
        # Setup TensorBoard logging on trainer and evaluators. Logged values are:
        #  - Training metrics, e.g. running average loss values
        #  - Learning rate
        #  - Evaluation train/test metrics
        evaluators = {"training": train_evaluator, "test": evaluator}
        tb_logger = common.setup_tb_logging(output_path, trainer, optimizer, evaluators=evaluators)

    # Store 3 best models by validation accuracy:
    common.gen_save_best_models_by_val_score(
        save_handler=get_save_handler(config),
        evaluator=evaluator,
        models={"model": model},
        metric_name="accuracy",
        n_saved=3,
        trainer=trainer,
        tag="test",
    )

    # In order to check training resuming we can stop training on a given iteration
    if config["stop_iteration"] is not None:

        @trainer.on(Events.ITERATION_STARTED(once=config["stop_iteration"]))
        def _():
            logger.info("Stop training on iteration {}".format(trainer.state.iteration))
            trainer.terminate()

    @trainer.on(Events.ITERATION_COMPLETED(every=20))
    def print_acc(engine):
        if rank == 0:
            print(
                "Epoch[{}] Iteration[{}/{}] Loss: {:.3f}".format(
                    engine.state.epoch,
                    engine.state.iteration,
                    len(train_loader),
                    engine.state.saved_batch_loss,
                )
            )

    try:
        trainer.run(train_loader, max_epochs=config["num_epochs"])
    except Exception:
        import traceback

        print(traceback.format_exc())

    if rank == 0:
        tb_logger.close()
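# Hedged launch sketch for the ImageNet training function above. With the native backend,
# idist.Parallel can either spawn processes itself (via nproc_per_node=...) or, as assumed
# here, pick up a distributed environment created by e.g. `torchrun --nproc_per_node=4 main.py`.
# The config keys mirror the ones read directly inside training(); the values are placeholders,
# and keys consumed by get_imagenet_dataloader/initialize are omitted.
import ignite.distributed as idist

def run():
    config = {
        "seed": 543,
        "output_path": "/tmp/output-imagenet",
        "model": "resnet50",
        "stop_iteration": None,
        "validate_every": 1,
        "num_epochs": 1,
    }
    with idist.Parallel(backend="nccl") as parallel:
        parallel.run(training, config)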