def test_get_handlers(tmp_path):
    trainer = Engine(lambda e, b: b)
    config = Namespace(
        output_dir=tmp_path,
        save_every_iters=1,
        n_saved=2,
        log_every_iters=1,
        with_pbars=False,
        with_pbar_on_iters=False,
        stop_on_nan=False,
        clear_cuda_cache=False,
        with_gpu_stats=False,
        patience=1,
        limit_sec=30,
    )
    bm_handler, es_handler, timer_handler = get_handlers(
        config=config,
        model=nn.Linear(1, 1),
        trainer=trainer,
        evaluator=trainer,
        metric_name="eval_loss",
        es_metric_name="eval_loss",
    )
    assert isinstance(bm_handler, (type(None), Checkpoint)), "Should be Checkpoint or None"
    assert isinstance(es_handler, (type(None), EarlyStopping)), "Should be EarlyStopping or None"
    assert isinstance(timer_handler, (type(None), Timer)), "Should be Timer or None"
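# For reference, a minimal sketch of how the three handlers checked above could be
# constructed and attached with the standard ignite handlers API. This is an
# illustrative assumption about what `get_handlers` wires up, not the template's
# actual implementation; all names below mirror the test's arguments.
from ignite.engine import Events
from ignite.handlers import Checkpoint, DiskSaver, EarlyStopping, Timer


def sketch_get_handlers(config, model, trainer, evaluator, metric_name, es_metric_name):
    # checkpoint the best model according to the evaluation metric
    bm_handler = Checkpoint(
        {"model": model},
        DiskSaver(config.output_dir, require_empty=False),
        filename_prefix="best",
        n_saved=config.n_saved,
        score_name=metric_name,
        score_function=lambda engine: engine.state.metrics[metric_name],
        global_step_transform=lambda *_: trainer.state.epoch,
    )
    evaluator.add_event_handler(Events.COMPLETED, bm_handler)

    # stop training early when the evaluation metric stops improving
    es_handler = EarlyStopping(
        patience=config.patience,
        score_function=lambda engine: engine.state.metrics[es_metric_name],
        trainer=trainer,
    )
    evaluator.add_event_handler(Events.COMPLETED, es_handler)

    # measure the average time per iteration
    timer_handler = Timer(average=True)
    timer_handler.attach(
        trainer,
        start=Events.EPOCH_STARTED,
        resume=Events.ITERATION_STARTED,
        pause=Events.ITERATION_COMPLETED,
        step=Events.ITERATION_COMPLETED,
    )
    return bm_handler, es_handler, timer_handler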
def run(local_rank: int, config: Any, *args: Any, **kwargs: Any):
    """function to be run by idist.Parallel context manager."""

    # ----------------------
    # make a certain seed
    # ----------------------
    rank = idist.get_rank()
    manual_seed(config.seed + rank)

    # -----------------------
    # create output folder
    # -----------------------
    if rank == 0:
        now = datetime.now().strftime("%Y%m%d-%H%M%S")
        name = f"{config.model}-backend-{idist.backend()}-{now}"
        path = Path(config.output_dir, name)
        path.mkdir(parents=True, exist_ok=True)
        config.output_dir = path.as_posix()

    config.output_dir = Path(idist.broadcast(config.output_dir, src=0))

    # -----------------------------
    # datasets and dataloaders
    # -----------------------------
    # TODO : PLEASE provide your custom datasets and dataloaders configurations
    # we can use `idist.auto_dataloader` to handle distributed configurations
    # TODO : PLEASE replace `kwargs` with your desirable DataLoader arguments
    # See : https://pytorch.org/ignite/distributed.html#ignite.distributed.auto.auto_dataloader
    train_dataset, eval_dataset = get_datasets(path=config.data_path)

    train_dataloader = idist.auto_dataloader(
        train_dataset,
        batch_size=config.train_batch_size,
        num_workers=config.num_workers,
        shuffle=True,
        {% if use_distributed_training and not use_distributed_launcher %}
        persistent_workers=True,
        {% endif %}
    )
    eval_dataloader = idist.auto_dataloader(
        eval_dataset,
        batch_size=config.eval_batch_size,
        num_workers=config.num_workers,
        shuffle=False,
        {% if use_distributed_training and not use_distributed_launcher %}
        persistent_workers=True,
        {% endif %}
    )

    # ------------------------------------------
    # model, optimizer, loss function, device
    # ------------------------------------------
    device = idist.device()
    config.num_iters_per_epoch = len(train_dataloader)
    model, optimizer, loss_fn, lr_scheduler = initialize(config=config)

    # -----------------------------
    # trainer and evaluator
    # -----------------------------
    trainer, evaluator = create_trainers(
        config=config,
        model=model,
        optimizer=optimizer,
        loss_fn=loss_fn,
        device=device,
    )

    # ---------------------------------
    # attach metrics to evaluator
    # ---------------------------------
    accuracy = Accuracy(device=device)
    metrics = {
        "eval_accuracy": accuracy,
        "eval_loss": Loss(loss_fn, device=device),
        "eval_error": (1.0 - accuracy) * 100,
    }
    for name, metric in metrics.items():
        metric.attach(evaluator, name)

    # -------------------------------------------
    # setup engines logger with python logging
    # print training configurations
    # -------------------------------------------
    logger = setup_logging(config)
    log_basic_info(logger, config)
    trainer.logger = logger
    evaluator.logger = logger

    # -------------------------------------
    # ignite handlers and ignite loggers
    # -------------------------------------
    to_save = {"model": model, "optimizer": optimizer, "trainer": trainer, "lr_scheduler": lr_scheduler}
    best_model_handler, es_handler, timer_handler = get_handlers(
        config=config,
        model=model,
        trainer=trainer,
        evaluator=evaluator,
        metric_name="eval_accuracy",
        es_metric_name="eval_accuracy",
        to_save=to_save,
        lr_scheduler=lr_scheduler,
        output_names=None,
    )

    # setup ignite logger only on rank 0
    if rank == 0:
        logger_handler = get_logger(
            config=config, trainer=trainer, evaluator=evaluator, optimizers=optimizer
        )

    # -----------------------------------
    # resume from the saved checkpoints
    # -----------------------------------
    if config.resume_from:
        resume_from(to_load=to_save, checkpoint_fp=config.resume_from)

    # --------------------------------
    # print metrics to the stderr
    # with `add_event_handler` API
    # for training stats
    # --------------------------------
    trainer.add_event_handler(Events.ITERATION_COMPLETED(every=config.log_every_iters), log_metrics, tag="train")

    # ---------------------------------------------
    # run evaluation at every training epoch end
    # with shortcut `on` decorator API and
    # print metrics to the stderr
    # again with `add_event_handler` API
    # for evaluation stats
    # ---------------------------------------------
    @trainer.on(Events.EPOCH_COMPLETED(every=1))
    def _():
        evaluator.run(eval_dataloader, epoch_length=config.eval_epoch_length)
        log_metrics(evaluator, "eval")

    # --------------------------------------------------
    # let's try run evaluation first as a sanity check
    # --------------------------------------------------
    @trainer.on(Events.STARTED)
    def _():
        evaluator.run(eval_dataloader, epoch_length=config.eval_epoch_length)

    # ------------------------------------------
    # setup is done. let's run the training
    # ------------------------------------------
    trainer.run(train_dataloader, max_epochs=config.max_epochs, epoch_length=config.train_epoch_length)

    # ------------------------------------------------------------
    # close the logger after the training completed / terminated
    # ------------------------------------------------------------
    if rank == 0:
        from ignite.contrib.handlers.wandb_logger import WandBLogger

        if isinstance(logger_handler, WandBLogger):
            # why handle differently for wandb ?
            # See : https://github.com/pytorch/ignite/issues/1894
            logger_handler.finish()
        elif logger_handler:
            logger_handler.close()

    # -----------------------------------------
    # where is my best and last checkpoint ?
    # -----------------------------------------
    if best_model_handler is not None:
        logger.info("Last and best checkpoint: %s", best_model_handler.last_checkpoint)
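# The `run` function above is designed to be launched through `idist.Parallel`,
# as its docstring notes. A minimal sketch of such an entry point follows; the
# `backend` config field and the `setup_config` helper are illustrative
# assumptions, not part of the template shown here.
import ignite.distributed as idist


def main():
    config = setup_config()  # hypothetical helper that parses CLI args into a Namespace
    # backend=None runs single-process; e.g. "gloo" or "nccl" enables distributed training
    with idist.Parallel(backend=config.backend) as parallel:
        parallel.run(run, config)


if __name__ == "__main__":
    main()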
def run(local_rank: int, config: Any, *args: Any, **kwargs: Any):
    """function to be run by idist.Parallel context manager."""

    # ----------------------
    # make a certain seed
    # ----------------------
    rank = idist.get_rank()
    manual_seed(config.seed + rank)

    # -----------------------
    # create output folder
    # -----------------------
    if rank == 0:
        now = datetime.now().strftime("%Y%m%d-%H%M%S")
        name = f"{config.dataset}-backend-{idist.backend()}-{now}"
        path = Path(config.output_dir, name)
        path.mkdir(parents=True, exist_ok=True)
        config.output_dir = path.as_posix()

    config.output_dir = Path(idist.broadcast(config.output_dir, src=0))

    # -----------------------------
    # datasets and dataloaders
    # -----------------------------
    train_dataset, num_channels = get_datasets(config.dataset, config.data_path)

    train_dataloader = idist.auto_dataloader(
        train_dataset,
        batch_size=config.batch_size,
        num_workers=config.num_workers,
        {% if use_distributed_training and not use_distributed_launcher %}
        persistent_workers=True,
        {% endif %}
    )

    # ------------------------------------------
    # model, optimizer, loss function, device
    # ------------------------------------------
    device = idist.device()
    netD, netG, optimizerD, optimizerG, loss_fn, lr_scheduler = initialize(config, num_channels)

    # -----------------------------
    # trainer and evaluator
    # -----------------------------
    ws = idist.get_world_size()
    real_labels = torch.ones(config.batch_size // ws, device=device)
    fake_labels = torch.zeros(config.batch_size // ws, device=device)
    fixed_noise = torch.randn(config.batch_size // ws, config.z_dim, 1, 1, device=device)

    trainer = create_trainers(
        config=config,
        netD=netD,
        netG=netG,
        optimizerD=optimizerD,
        optimizerG=optimizerG,
        loss_fn=loss_fn,
        device=device,
        real_labels=real_labels,
        fake_labels=fake_labels,
    )

    # -------------------------------------------
    # setup engines logger with python logging
    # print training configurations
    # -------------------------------------------
    logger = setup_logging(config)
    log_basic_info(logger, config)
    trainer.logger = logger

    # -------------------------------------
    # ignite handlers and ignite loggers
    # -------------------------------------
    to_save = {"netD": netD, "netG": netG, "optimizerD": optimizerD, "optimizerG": optimizerG, "trainer": trainer}
    optimizers = {"optimizerD": optimizerD, "optimizerG": optimizerG}
    best_model_handler, es_handler, timer_handler = get_handlers(
        config=config,
        # NOTE: this must be a dict (not a set) so both networks are tracked by name
        model={"netD": netD, "netG": netG},
        trainer=trainer,
        evaluator=trainer,
        metric_name="errD",
        es_metric_name="errD",
        to_save=to_save,
        lr_scheduler=lr_scheduler,
        output_names=["errD", "errG", "D_x", "D_G_z1", "D_G_z2"],
    )

    # setup ignite logger only on rank 0
    if rank == 0:
        logger_handler = get_logger(config=config, trainer=trainer, optimizers=optimizers)

    # -----------------------------------
    # resume from the saved checkpoints
    # -----------------------------------
    if config.resume_from:
        resume_from(to_load=to_save, checkpoint_fp=config.resume_from)

    # --------------------------------------------------
    # adding handlers using `trainer.on` decorator API
    # --------------------------------------------------
    @trainer.on(Events.EPOCH_COMPLETED)
    def save_fake_example(engine):
        fake = netG(fixed_noise)
        path = config.output_dir / (FAKE_IMG_FNAME.format(engine.state.epoch))
        vutils.save_image(fake.detach(), path, normalize=True)

    # --------------------------------------------------
    # adding handlers using `trainer.on` decorator API
    # --------------------------------------------------
    @trainer.on(Events.EPOCH_COMPLETED)
    def save_real_example(engine):
        img, y = engine.state.batch
        path = config.output_dir / (REAL_IMG_FNAME.format(engine.state.epoch))
        vutils.save_image(img, path, normalize=True)

    # -------------------------------------------------------------
    # adding handlers using `trainer.on` decorator API
    # -------------------------------------------------------------
    @trainer.on(Events.EPOCH_COMPLETED)
    def print_times(engine):
        if timer_handler:
            logger.info(f"Epoch {engine.state.epoch} done. Time per batch: {timer_handler.value():.3f}[s]")
            timer_handler.reset()

    @trainer.on(Events.ITERATION_COMPLETED(every=config.log_every_iters))
    @idist.one_rank_only()
    def print_logs(engine):
        fname = config.output_dir / LOGS_FNAME
        columns = ["iteration"] + list(engine.state.metrics.keys())
        values = [str(engine.state.iteration)] + [str(round(value, 5)) for value in engine.state.metrics.values()]

        with open(fname, "a") as f:
            if f.tell() == 0:
                print("\t".join(columns), file=f)
            print("\t".join(values), file=f)

        message = f"[{engine.state.epoch}/{config.max_epochs}][{engine.state.iteration % len(train_dataloader)}/{len(train_dataloader)}]"
        for name, value in zip(columns, values):
            message += f" | {name}: {value}"
        # emit the assembled progress message
        logger.info(message)

    # -------------------------------------------------------------
    # adding handlers using `trainer.on` decorator API
    # -------------------------------------------------------------
    @trainer.on(Events.EPOCH_COMPLETED)
    def create_plots(engine):
        try:
            import matplotlib as mpl

            mpl.use("agg")

            import matplotlib.pyplot as plt
            import pandas as pd
        except ImportError:
            warnings.warn("Loss plots will not be generated -- pandas or matplotlib not found")
        else:
            df = pd.read_csv(config.output_dir / LOGS_FNAME, delimiter="\t", index_col="iteration")
            _ = df.plot(subplots=True, figsize=(20, 20))
            _ = plt.xlabel("Iteration number")
            fig = plt.gcf()
            path = config.output_dir / PLOT_FNAME
            fig.savefig(path)

    # --------------------------------
    # print metrics to the stderr
    # with `add_event_handler` API
    # for training stats
    # --------------------------------
    trainer.add_event_handler(Events.ITERATION_COMPLETED(every=config.log_every_iters), log_metrics, tag="train")

    # ------------------------------------------
    # setup is done. let's run the training
    # ------------------------------------------
    trainer.run(train_dataloader, max_epochs=config.max_epochs, epoch_length=config.train_epoch_length)

    # ------------------------------------------------------------
    # close the logger after the training completed / terminated
    # ------------------------------------------------------------
    if rank == 0:
        from ignite.contrib.handlers.wandb_logger import WandBLogger

        if isinstance(logger_handler, WandBLogger):
            # why handle differently for wandb ?
            # See : https://github.com/pytorch/ignite/issues/1894
            logger_handler.finish()
        elif logger_handler:
            logger_handler.close()

    # -----------------------------------------
    # where is my best and last checkpoint ?
    # -----------------------------------------
    if best_model_handler is not None:
        logger.info("Last and best checkpoint: %s", best_model_handler.last_checkpoint)
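# For context, a minimal sketch of what the GAN `create_trainers` used above could
# look like: a single `Engine` whose process function updates the discriminator on
# real and fake batches, then the generator. This is an illustrative assumption
# based on the standard DCGAN recipe, not the template's actual implementation;
# the returned keys mirror the `output_names` passed to `get_handlers`.
import torch
from ignite.engine import Engine


def sketch_create_trainers(config, netD, netG, optimizerD, optimizerG, loss_fn, device, real_labels, fake_labels):
    def train_step(engine, batch):
        real, _ = batch
        real = real.to(device, non_blocking=True)
        noise = torch.randn(real.size(0), config.z_dim, 1, 1, device=device)

        # (1) update D: maximize log(D(x)) + log(1 - D(G(z)))
        netD.zero_grad()
        output = netD(real).view(-1)  # netD is assumed to yield one score per sample
        errD_real = loss_fn(output, real_labels)
        errD_real.backward()
        D_x = output.mean().item()

        fake = netG(noise)
        output = netD(fake.detach()).view(-1)
        errD_fake = loss_fn(output, fake_labels)
        errD_fake.backward()
        D_G_z1 = output.mean().item()
        optimizerD.step()

        # (2) update G: maximize log(D(G(z)))
        netG.zero_grad()
        output = netD(fake).view(-1)
        errG = loss_fn(output, real_labels)
        errG.backward()
        D_G_z2 = output.mean().item()
        optimizerG.step()

        return {
            "errD": (errD_real + errD_fake).item(),
            "errG": errG.item(),
            "D_x": D_x,
            "D_G_z1": D_G_z1,
            "D_G_z2": D_G_z2,
        }

    return Engine(train_step)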
def run(local_rank: int, config: Any, *args: Any, **kwargs: Any):
    """function to be run by idist.Parallel context manager."""

    # ----------------------
    # make a certain seed
    # ----------------------
    rank = idist.get_rank()
    manual_seed(config.seed + rank)

    # -----------------------
    # create output folder
    # -----------------------
    if rank == 0:
        now = datetime.now().strftime("%Y%m%d-%H%M%S")
        name = f"{config.dataset}-backend-{idist.backend()}-{now}"
        path = Path(config.output_dir, name)
        path.mkdir(parents=True, exist_ok=True)
        config.output_dir = path.as_posix()

    config.output_dir = Path(idist.broadcast(config.output_dir, src=0))

    # -----------------------------
    # datasets and dataloaders
    # -----------------------------
    # TODO : PLEASE provide your custom datasets and dataloaders configurations
    # we can use `idist.auto_dataloader` to handle distributed configurations
    # TODO : PLEASE replace `kwargs` with your desirable DataLoader arguments
    # See : https://pytorch.org/ignite/distributed.html#ignite.distributed.auto.auto_dataloader
    train_dataset, eval_dataset = get_datasets()
    train_dataloader = idist.auto_dataloader(train_dataset, **kwargs)
    eval_dataloader = idist.auto_dataloader(eval_dataset, **kwargs)

    # ------------------------------------------
    # model, optimizer, loss function, device
    # ------------------------------------------
    device = idist.device()
    model, optimizer, loss_fn, lr_scheduler = initialize()

    # -----------------------------
    # trainer and evaluator
    # -----------------------------
    trainer, evaluator = create_trainers(
        config=config,
        model=model,
        optimizer=optimizer,
        loss_fn=loss_fn,
        device=device,
    )

    # -------------------------------------------
    # update config with optimizer parameters
    # setup engines logger with python logging
    # print training configurations
    # -------------------------------------------
    config.__dict__.update(**optimizer.defaults)
    logger = setup_logging(config)
    log_basic_info(logger, config)
    trainer.logger = logger
    evaluator.logger = logger

    # -------------------------------------
    # ignite handlers and ignite loggers
    # -------------------------------------
    to_save = {"model": model, "optimizer": optimizer, "trainer": trainer, "lr_scheduler": lr_scheduler}
    best_model_handler, es_handler, timer_handler = get_handlers(
        config=config,
        model=model,
        trainer=trainer,
        evaluator=evaluator,
        # TODO : replace with the metric name to save the best model
        # if you check `Save the best model by evaluation score` otherwise leave it None
        # metric must be in evaluator.state.metrics.
        metric_name=None,
        # TODO : replace with the metric name to early stop
        # if you check `Early stop the training by evaluation score` otherwise leave it None
        # metric must be in evaluator.state.metrics.
        es_metric_name=None,
        to_save=to_save,
        lr_scheduler=lr_scheduler,
        output_names=None,
    )

    # setup ignite logger only on rank 0
    if rank == 0:
        logger_handler = get_logger(
            config=config, trainer=trainer, evaluator=evaluator, optimizers=optimizer
        )

    # -----------------------------------
    # resume from the saved checkpoints
    # -----------------------------------
    if config.resume_from:
        resume_from(to_load=to_save, checkpoint_fp=config.resume_from)

    # --------------------------------------------
    # let's trigger custom events we registered
    # we will use an `event_filter` to trigger that
    # the `event_filter` has to return a boolean
    # whether this event should be executed
    # here we will log the gradients on the 1st iteration
    # and every 100 iterations
    # --------------------------------------------
    @trainer.on(TrainEvents.BACKWARD_COMPLETED(lambda _, ev: (ev % 100 == 0) or (ev == 1)))
    def _():
        # do something interesting
        pass

    # ----------------------------------------
    # here we will use `every` to trigger
    # every 100 iterations
    # ----------------------------------------
    @trainer.on(TrainEvents.OPTIM_STEP_COMPLETED(every=100))
    def _():
        # do something interesting
        pass

    # --------------------------------
    # print metrics to the stderr
    # with `add_event_handler` API
    # for training stats
    # --------------------------------
    trainer.add_event_handler(Events.ITERATION_COMPLETED(every=config.log_every_iters), log_metrics, tag="train")

    # ---------------------------------------------
    # run evaluation at every training epoch end
    # with shortcut `on` decorator API and
    # print metrics to the stderr
    # again with `add_event_handler` API
    # for evaluation stats
    # ---------------------------------------------
    @trainer.on(Events.EPOCH_COMPLETED(every=1))
    def _():
        evaluator.run(eval_dataloader, epoch_length=config.eval_epoch_length)
        log_metrics(evaluator, "eval")

    # --------------------------------------------------
    # let's try run evaluation first as a sanity check
    # --------------------------------------------------
    @trainer.on(Events.STARTED)
    def _():
        evaluator.run(eval_dataloader, epoch_length=config.eval_epoch_length)

    # ------------------------------------------
    # setup is done. let's run the training
    # ------------------------------------------
    # TODO : PLEASE provide `max_epochs` parameters
    trainer.run(train_dataloader, max_epochs=config.max_epochs, epoch_length=config.train_epoch_length)

    # ------------------------------------------------------------
    # close the logger after the training completed / terminated
    # ------------------------------------------------------------
    if rank == 0:
        from ignite.contrib.handlers.wandb_logger import WandBLogger

        if isinstance(logger_handler, WandBLogger):
            # why handle differently for wandb ?
            # See : https://github.com/pytorch/ignite/issues/1894
            logger_handler.finish()
        elif logger_handler:
            logger_handler.close()

    # -----------------------------------------
    # where is my best and last checkpoint ?
    # -----------------------------------------
    if best_model_handler is not None:
        logger.info("Last and best checkpoint: %s", best_model_handler.last_checkpoint)
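# The `TrainEvents.BACKWARD_COMPLETED` / `OPTIM_STEP_COMPLETED` handlers above
# assume custom events that are registered on the trainer and fired from the
# training step. A minimal sketch of that wiring, assuming `TrainEvents` is an
# `EventEnum` defined next to `create_trainers` (the step body is elided and
# purely illustrative):
from ignite.engine import Engine, EventEnum


class TrainEvents(EventEnum):
    BACKWARD_COMPLETED = "backward_completed"
    OPTIM_STEP_COMPLETED = "optim_step_completed"


# map each custom event to a state counter so that `every=...` filters can work
train_events_to_attr = {
    TrainEvents.BACKWARD_COMPLETED: "backward_completed",
    TrainEvents.OPTIM_STEP_COMPLETED: "optim_step_completed",
}


def sketch_train_step(engine, batch):
    # forward pass, loss computation and loss.backward() elided for brevity
    engine.state.backward_completed += 1
    engine.fire_event(TrainEvents.BACKWARD_COMPLETED)
    # optimizer.step() elided
    engine.state.optim_step_completed += 1
    engine.fire_event(TrainEvents.OPTIM_STEP_COMPLETED)
    return 0.0


trainer = Engine(sketch_train_step)
trainer.register_events(*TrainEvents, event_to_attr=train_events_to_attr)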
def run(local_rank, config):

    # ----------------------
    # Make a certain seed
    # ----------------------
    rank = idist.get_rank()
    manual_seed(config.seed + rank)
    device = idist.device()

    # -----------------------
    # Create output folder
    # -----------------------
    if rank == 0:
        now = datetime.now().strftime("%Y%m%d-%H%M%S")
        name = f"{config.model}-backend-{idist.backend()}-{now}"
        path = Path(config.output_dir, name)
        path.mkdir(parents=True, exist_ok=True)
        config.output_dir = path.as_posix()

    config.output_dir = Path(idist.broadcast(config.output_dir, src=0))

    # -----------------------------
    # datasets and dataloaders
    # -----------------------------
    train_loader, test_loader = get_dataflow(config)

    # ------------------------------------------
    # model, optimizer, loss function, device
    # ------------------------------------------
    config.num_iters_per_epoch = len(train_loader)
    model, optimizer, loss_fn, lr_scheduler = initialize(config)

    # -----------------------------
    # trainer and evaluator
    # -----------------------------
    trainer, evaluator = create_trainers(
        config=config,
        model=model,
        optimizer=optimizer,
        loss_fn=loss_fn,
        device=device,
    )

    # ---------------------------------
    # attach metrics to evaluator
    # ---------------------------------
    metrics = {
        "eval_accuracy": Accuracy(output_transform=thresholded_output_transform, device=device),
        "eval_loss": Loss(loss_fn, device=device),
    }
    for name, metric in metrics.items():
        metric.attach(evaluator, name)

    # -------------------------------------------
    # setup engines logger with python logging
    # print training configurations
    # -------------------------------------------
    logger = setup_logging(config)
    log_basic_info(logger, config)
    trainer.logger = logger
    evaluator.logger = logger

    # -------------------------------------
    # ignite handlers and ignite loggers
    # -------------------------------------
    to_save = {"model": model, "optimizer": optimizer, "trainer": trainer, "lr_scheduler": lr_scheduler}
    best_model_handler, es_handler, timer_handler = get_handlers(
        config=config,
        model=model,
        trainer=trainer,
        evaluator=evaluator,
        metric_name="eval_accuracy",
        es_metric_name="eval_accuracy",
        to_save=to_save,
        lr_scheduler=lr_scheduler,
        output_names=None,
    )

    # setup ignite logger only on rank 0
    if rank == 0:
        logger_handler = get_logger(
            config=config, trainer=trainer, evaluator=evaluator, optimizers=optimizer
        )

    # -----------------------------------
    # resume from the saved checkpoints
    # -----------------------------------
    if config.resume_from:
        resume_from(to_load=to_save, checkpoint_fp=config.resume_from)

    # --------------------------------
    # print metrics to the stderr
    # with `add_event_handler` API
    # for training stats
    # --------------------------------
    trainer.add_event_handler(Events.ITERATION_COMPLETED(every=config.log_every_iters), log_metrics, tag="train")

    # ---------------------------------------------
    # run evaluation at every training epoch end
    # with shortcut `on` decorator API and
    # print metrics to the stderr
    # again with `add_event_handler` API
    # for evaluation stats
    # ---------------------------------------------
    @trainer.on(Events.EPOCH_COMPLETED(every=config.validate_every))
    def _():
        evaluator.run(test_loader, epoch_length=config.eval_epoch_length)
        log_metrics(evaluator, tag="eval")

    # --------------------------------------------------
    # let's try run evaluation first as a sanity check
    # --------------------------------------------------
    @trainer.on(Events.STARTED)
    def _():
        evaluator.run(test_loader, epoch_length=config.eval_epoch_length)

    # ------------------------------------------
    # setup is done. let's run the training
    # ------------------------------------------
    trainer.run(train_loader, max_epochs=config.max_epochs, epoch_length=config.train_epoch_length)

    # ------------------------------------------------------------
    # close the logger after the training completed / terminated
    # ------------------------------------------------------------
    if rank == 0:
        from ignite.contrib.handlers.wandb_logger import WandBLogger

        if isinstance(logger_handler, WandBLogger):
            # why handle differently for wandb ?
            # See : https://github.com/pytorch/ignite/issues/1894
            logger_handler.finish()
        elif logger_handler:
            logger_handler.close()

    # -----------------------------------------
    # where is my best and last checkpoint ?
    # -----------------------------------------
    if best_model_handler is not None:
        logger.info("Last and best checkpoint: %s", best_model_handler.last_checkpoint)
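# `thresholded_output_transform` used above is assumed to binarize model outputs
# before they reach `Accuracy`. A minimal sketch of such a transform for a
# single-logit binary classifier (the actual helper is defined elsewhere in the
# template):
import torch


def thresholded_output_transform(output):
    # `output` is whatever the evaluator's process function returns: (y_pred, y)
    y_pred, y = output
    y_pred = torch.round(torch.sigmoid(y_pred))
    return y_pred, y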