def gen_save_best_models_by_val_score(
    save_handler: Union[Callable, BaseSaveHandler],
    evaluator: Engine,
    models: Union[torch.nn.Module, Dict[str, torch.nn.Module]],
    metric_name: str,
    n_saved: int = 3,
    trainer: Optional[Engine] = None,
    tag: str = "val",
    **kwargs: Any,
) -> Checkpoint:
    """Register a best-model :class:`~ignite.handlers.Checkpoint` on ``evaluator``.

    Each time ``evaluator`` completes, the objects in ``models`` are scored with
    ``evaluator.state.metrics[metric_name]``. The ``n_saved`` checkpoints with the
    highest score are retained. Persisting (and optionally removing) checkpoint
    files is delegated entirely to ``save_handler``.

    Args:
        save_handler: function or callable object used to persist a checkpoint.
            It receives the checkpoint dictionary and a filename. A callable class
            may inherit from :class:`~ignite.handlers.checkpoint.BaseSaveHandler`
            and optionally implement ``remove`` so that only a fixed number of
            checkpoints is kept. To save on disk, use
            :class:`~ignite.handlers.DiskSaver`.
        evaluator: evaluation engine that produces the score.
        models: a single model or a dictionary of objects to save. Every object
            must implement ``state_dict`` and ``load_state_dict``. A lone
            ``nn.Module`` is stored under the key ``"model"``.
        metric_name: name of the metric in ``evaluator.state.metrics`` used as
            the score.
        n_saved: number of best checkpoints to keep.
        trainer: optional trainer engine. When given, its epoch is fetched as the
            global step for the saved checkpoints.
        tag: prefix of the score name, which becomes
            ``f"{tag}_{metric_name.lower()}"``. Defaults to ``"val"``.
        kwargs: extra keyword arguments forwarded to
            :class:`~ignite.handlers.Checkpoint`.

    Returns:
        The :class:`~ignite.handlers.Checkpoint` handler attached to ``evaluator``.
    """
    # Without a trainer there is no epoch to report, so no step transform is used.
    step_transform = global_step_from_engine(trainer) if trainer is not None else None

    # Normalise a bare module into the mapping form Checkpoint expects.
    objects_to_save: Dict[str, nn.Module]
    if isinstance(models, nn.Module):
        objects_to_save = {"model": models}
    else:
        objects_to_save = models

    score_label = f"{tag}_{metric_name.lower()}"
    score_fn = Checkpoint.get_default_score_fn(metric_name)

    handler = Checkpoint(
        objects_to_save,
        save_handler,
        filename_prefix="best",
        n_saved=n_saved,
        global_step_transform=step_transform,
        score_name=score_label,
        score_function=score_fn,
        **kwargs,
    )
    evaluator.add_event_handler(Events.COMPLETED, handler)
    return handler
def training(local_rank, config):
    """Run CIFAR10 training for one (possibly distributed) process.

    Rank 0 additionally creates the output folder, optionally initialises a
    ClearML task, sets up TensorBoard logging and closes it when training ends.

    Args:
        local_rank: local process rank. It is used for the logger and for
            querying the CUDA device name.
        config: mutable configuration dict. It is read for ``seed``,
            ``output_path``, ``stop_iteration``, ``model``, ``with_clearml``,
            ``num_epochs`` and ``validate_every``. It is updated in place with
            ``output_path`` (rank 0), ``cuda device name`` (rank 0 on CUDA) and
            ``num_iters_per_epoch``.

    Raises:
        Any exception raised by ``trainer.run`` is logged and re-raised.
    """
    rank = idist.get_rank()
    manual_seed(config["seed"] + rank)
    device = idist.device()

    logger = setup_logger(name="CIFAR10-Training", distributed_rank=local_rank)

    log_basic_info(logger, config)

    output_path = config["output_path"]
    if rank == 0:
        if config["stop_iteration"] is None:
            now = datetime.now().strftime("%Y%m%d-%H%M%S")
        else:
            # Deterministic folder name when stopping early, presumably so a
            # resume check can find this run's output again.
            now = f"stop-on-{config['stop_iteration']}"

        folder_name = f"{config['model']}_backend-{idist.backend()}-{idist.get_world_size()}_{now}"
        output_path = Path(output_path) / folder_name
        output_path.mkdir(parents=True, exist_ok=True)
        config["output_path"] = output_path.as_posix()
        logger.info(f"Output path: {config['output_path']}")

        if "cuda" in device.type:
            config["cuda device name"] = torch.cuda.get_device_name(local_rank)

        if config["with_clearml"]:
            try:
                from clearml import Task
            except ImportError:
                # Backwards-compatibility for legacy Trains SDK
                from trains import Task

            task = Task.init("CIFAR10-Training", task_name=output_path.stem)
            task.connect_configuration(config)
            # Log hyper parameters
            hyper_params = [
                "model",
                "batch_size",
                "momentum",
                "weight_decay",
                "num_epochs",
                "learning_rate",
                "num_warmup_epochs",
            ]
            task.connect({k: config[k] for k in hyper_params})

    # Setup dataflow, model, optimizer, criterion
    train_loader, test_loader = get_dataflow(config)

    config["num_iters_per_epoch"] = len(train_loader)
    model, optimizer, criterion, lr_scheduler = initialize(config)

    # Create trainer for current task
    trainer = create_trainer(model, optimizer, criterion, lr_scheduler, train_loader.sampler, config, logger)

    # Let's now setup evaluator engine to perform model's validation and compute metrics
    metrics = {
        "Accuracy": Accuracy(),
        "Loss": Loss(criterion),
    }

    # We define two evaluators as they won't have exactly the same roles:
    # - `evaluator` will save the best model based on validation score
    evaluator = create_evaluator(model, metrics=metrics, config=config)
    train_evaluator = create_evaluator(model, metrics=metrics, config=config)

    def run_validation(engine):
        # Evaluate on both splits and log the metrics under the trainer's current epoch.
        epoch = trainer.state.epoch
        state = train_evaluator.run(train_loader)
        log_metrics(logger, epoch, state.times["COMPLETED"], "Train", state.metrics)
        state = evaluator.run(test_loader)
        log_metrics(logger, epoch, state.times["COMPLETED"], "Test", state.metrics)

    trainer.add_event_handler(Events.EPOCH_COMPLETED(every=config["validate_every"]) | Events.COMPLETED, run_validation)

    if rank == 0:
        # Setup TensorBoard logging on trainer and evaluators. Logged values are:
        #  - Training metrics, e.g. running average loss values
        #  - Learning rate
        #  - Evaluation train/test metrics
        evaluators = {"training": train_evaluator, "test": evaluator}
        tb_logger = common.setup_tb_logging(output_path, trainer, optimizer, evaluators=evaluators)

    # Store 2 best models by validation accuracy, only once the trainer is past num_epochs // 2:
    best_model_handler = Checkpoint(
        {"model": model},
        get_save_handler(config),
        filename_prefix="best",
        n_saved=2,
        global_step_transform=global_step_from_engine(trainer),
        score_name="test_accuracy",
        score_function=Checkpoint.get_default_score_fn("Accuracy"),
    )
    evaluator.add_event_handler(
        Events.COMPLETED(lambda *_: trainer.state.epoch > config["num_epochs"] // 2), best_model_handler
    )

    # In order to check training resuming we can stop training on a given iteration
    if config["stop_iteration"] is not None:

        @trainer.on(Events.ITERATION_STARTED(once=config["stop_iteration"]))
        def _():
            logger.info(f"Stop training on {trainer.state.iteration} iteration")
            trainer.terminate()

    try:
        trainer.run(train_loader, max_epochs=config["num_epochs"])
    except Exception:
        logger.exception("")
        raise
    finally:
        # Close the TensorBoard logger on failure too, not only after a successful run.
        if rank == 0:
            tb_logger.close()
def training(local_rank, config):
    """Run CIFAR10 training for one (possibly distributed) process.

    Rank 0 additionally creates a timestamped output folder, optionally
    initialises a ClearML task, attaches progress bars (if requested), sets up
    TensorBoard logging and closes it when training ends.

    Args:
        local_rank: local process rank. It is used for the logger and for
            querying the CUDA device name.
        config: mutable configuration dict. It is read for ``seed``,
            ``output_path``, ``model``, ``with_clearml``, ``smoke_test``,
            ``with_pbar``, ``validate_every`` and ``num_epochs``. It is updated
            in place with ``output_path`` (rank 0), ``cuda device name``
            (rank 0 on CUDA) and ``num_iters_per_epoch``.

    Raises:
        Any exception raised by ``trainer.run`` is logged and re-raised.
    """
    from itertools import islice

    rank = idist.get_rank()
    manual_seed(config["seed"] + rank)
    device = idist.device()

    logger = setup_logger(name="CIFAR10-Training", distributed_rank=local_rank)

    log_basic_info(logger, config)

    output_path = config["output_path"]
    if rank == 0:
        now = datetime.now().strftime("%Y%m%d-%H%M%S")

        folder_name = f"{config['model']}_backend-{idist.backend()}-{idist.get_world_size()}_{now}"
        output_path = Path(output_path) / folder_name
        output_path.mkdir(parents=True, exist_ok=True)
        config["output_path"] = output_path.as_posix()
        logger.info(f"Output path: {config['output_path']}")

        if "cuda" in device.type:
            config["cuda device name"] = torch.cuda.get_device_name(local_rank)

        if config["with_clearml"]:
            try:
                from clearml import Task
            except ImportError:
                # Backwards-compatibility for legacy Trains SDK
                from trains import Task

            task = Task.init("CIFAR10-Training", task_name=output_path.stem)
            task.connect_configuration(config)
            task.connect(config)

    # Setup dataflow, model, optimizer, criterion
    train_loader, test_loader = get_dataflow(config)

    config["num_iters_per_epoch"] = len(train_loader)
    model, optimizer, criterion, lr_scheduler = initialize(config)
    logger.info(f"# model parameters (M): {sum(p.numel() for p in model.parameters()) * 1e-6}")

    # Create trainer for current task
    trainer = create_trainer(model, optimizer, criterion, lr_scheduler, train_loader.sampler, config, logger)

    # Let's now setup evaluator engine to perform model's validation and
    # compute metrics
    metrics = {
        "Accuracy": Accuracy(),
        "Loss": Loss(criterion),
    }

    # We define two evaluators as they won't have exactly the same roles:
    # - `evaluator` will save the best model based on validation score
    evaluator = create_evaluator(model, metrics=metrics, config=config)
    train_evaluator = create_evaluator(model, metrics=metrics, config=config)

    if config["smoke_test"]:
        logger.info("Reduce the size of training and test dataloader as smoke_test=True")

        def get_batches(loader):
            # Materialise at most 5 batches. islice stops early on a shorter
            # loader instead of leaking a bare StopIteration from next().
            return list(islice(loader, 5))

        train_loader = get_batches(train_loader)
        test_loader = get_batches(test_loader)

    if config["with_pbar"] and rank == 0:
        ProgressBar(desc="Evaluation (train)", persist=False).attach(train_evaluator)
        ProgressBar(desc="Evaluation (val)", persist=False).attach(evaluator)

    def run_validation(engine):
        # Evaluate on both splits and log the metrics under the trainer's current epoch.
        epoch = trainer.state.epoch
        state = train_evaluator.run(train_loader)
        log_metrics(logger, epoch, state.times["COMPLETED"], "Train", state.metrics)
        state = evaluator.run(test_loader)
        log_metrics(logger, epoch, state.times["COMPLETED"], "Test", state.metrics)

    trainer.add_event_handler(
        Events.EPOCH_COMPLETED(every=config["validate_every"]) | Events.COMPLETED,
        run_validation,
    )

    if rank == 0:
        # Setup TensorBoard logging on trainer and evaluators. Logged values are:
        #  - Training metrics, e.g. running average loss values
        #  - Learning rate
        #  - Evaluation train/test metrics
        evaluators = {"training": train_evaluator, "test": evaluator}
        tb_logger = common.setup_tb_logging(output_path, trainer, optimizer, evaluators=evaluators)

    # Store the single best model by validation accuracy, only once the trainer is past num_epochs // 2:
    best_model_handler = Checkpoint(
        {"model": model},
        get_save_handler(config),
        filename_prefix="best",
        n_saved=1,
        global_step_transform=global_step_from_engine(trainer),
        score_name="test_accuracy",
        score_function=Checkpoint.get_default_score_fn("Accuracy"),
    )
    evaluator.add_event_handler(
        Events.COMPLETED(lambda *_: trainer.state.epoch > config["num_epochs"] // 2),
        best_model_handler,
    )

    try:
        trainer.run(train_loader, max_epochs=config["num_epochs"])
    except Exception:
        logger.exception("")
        raise
    finally:
        # Close the TensorBoard logger on failure too, not only after a successful run.
        if rank == 0:
            tb_logger.close()
to_save, save_handler, filename_prefix="training") # Attach the handler to the trainer trainer.add_event_handler(Events.EPOCH_COMPLETED, checkpoint_handler) # Store best model by validation accuracy best_model_handler = Checkpoint( {"model": model}, save_handler, filename_prefix="best", n_saved=1, score_name="accuracy", score_function=Checkpoint.get_default_score_fn("accuracy"), ) evaluator.add_event_handler(Events.COMPLETED, best_model_handler) # ### Setting up TensorBoard as an experiment tracking system tb_logger = TensorboardLogger(log_dir=output_path) # Attach handler to plot trainer's loss every 100 iterations tb_logger.attach_output_handler(