def run( cls, model: AbsESPnetModel, optimizers: Sequence[torch.optim.Optimizer], schedulers: Sequence[Optional[AbsScheduler]], train_iter_factory: AbsIterFactory, valid_iter_factory: AbsIterFactory, plot_attention_iter_factory: Optional[AbsIterFactory], trainer_options, distributed_option: DistributedOption, ) -> None: """Perform training. This method performs the main process of training.""" assert check_argument_types() # NOTE(kamo): Don't check the type more strictly as far trainer_options assert is_dataclass(trainer_options), type(trainer_options) assert len(optimizers) == len(schedulers), (len(optimizers), len(schedulers)) if isinstance(trainer_options.keep_nbest_models, int): keep_nbest_models = [trainer_options.keep_nbest_models] else: if len(trainer_options.keep_nbest_models) == 0: logging.warning("No keep_nbest_models is given. Change to [1]") trainer_options.keep_nbest_models = [1] keep_nbest_models = trainer_options.keep_nbest_models output_dir = Path(trainer_options.output_dir) reporter = Reporter() if trainer_options.use_amp: if LooseVersion(torch.__version__) < LooseVersion("1.6.0"): raise RuntimeError( "Require torch>=1.6.0 for Automatic Mixed Precision") if trainer_options.sharded_ddp: if fairscale is None: raise RuntimeError( "Requiring fairscale. Do 'pip install fairscale'") scaler = fairscale.optim.grad_scaler.ShardedGradScaler() else: scaler = GradScaler() else: scaler = None if trainer_options.resume and (output_dir / "checkpoint.pth").exists(): cls.resume( checkpoint=output_dir / "checkpoint.pth", model=model, optimizers=optimizers, schedulers=schedulers, reporter=reporter, scaler=scaler, ngpu=trainer_options.ngpu, ) start_epoch = reporter.get_epoch() + 1 if start_epoch == trainer_options.max_epoch + 1: logging.warning( f"The training has already reached at max_epoch: {start_epoch}" ) if distributed_option.distributed: if trainer_options.sharded_ddp: dp_model = fairscale.nn.data_parallel.ShardedDataParallel( module=model, sharded_optimizer=optimizers, ) else: dp_model = torch.nn.parallel.DistributedDataParallel( model, device_ids=( # Perform multi-Process with multi-GPUs [torch.cuda.current_device()] if distributed_option.ngpu == 1 # Perform single-Process with multi-GPUs else None), output_device=(torch.cuda.current_device() if distributed_option.ngpu == 1 else None), find_unused_parameters=trainer_options.unused_parameters, ) elif distributed_option.ngpu > 1: dp_model = torch.nn.parallel.DataParallel( model, device_ids=list(range(distributed_option.ngpu)), ) else: # NOTE(kamo): DataParallel also should work with ngpu=1, # but for debuggability it's better to keep this block. dp_model = model if trainer_options.use_tensorboard and ( not distributed_option.distributed or distributed_option.dist_rank == 0): from torch.utils.tensorboard import SummaryWriter train_summary_writer = SummaryWriter( str(output_dir / "tensorboard" / "train")) valid_summary_writer = SummaryWriter( str(output_dir / "tensorboard" / "valid")) else: train_summary_writer = None start_time = time.perf_counter() for iepoch in range(start_epoch, trainer_options.max_epoch + 1): if iepoch != start_epoch: logging.info( "{}/{}epoch started. Estimated time to finish: {}".format( iepoch, trainer_options.max_epoch, humanfriendly.format_timespan( (time.perf_counter() - start_time) / (iepoch - start_epoch) * (trainer_options.max_epoch - iepoch + 1)), )) else: logging.info( f"{iepoch}/{trainer_options.max_epoch}epoch started") set_all_random_seed(trainer_options.seed + iepoch) reporter.set_epoch(iepoch) # 1. Train and validation for one-epoch with reporter.observe("train") as sub_reporter: all_steps_are_invalid = cls.train_one_epoch( model=dp_model, optimizers=optimizers, schedulers=schedulers, iterator=train_iter_factory.build_iter(iepoch), reporter=sub_reporter, scaler=scaler, summary_writer=train_summary_writer, options=trainer_options, distributed_option=distributed_option, ) with reporter.observe("valid") as sub_reporter: cls.validate_one_epoch( model=dp_model, iterator=valid_iter_factory.build_iter(iepoch), reporter=sub_reporter, options=trainer_options, distributed_option=distributed_option, ) if not distributed_option.distributed or distributed_option.dist_rank == 0: # att_plot doesn't support distributed if plot_attention_iter_factory is not None: with reporter.observe("att_plot") as sub_reporter: cls.plot_attention( model=model, output_dir=output_dir / "att_ws", summary_writer=train_summary_writer, iterator=plot_attention_iter_factory.build_iter( iepoch), reporter=sub_reporter, options=trainer_options, ) # 2. LR Scheduler step for scheduler in schedulers: if isinstance(scheduler, AbsValEpochStepScheduler): scheduler.step( reporter.get_value( *trainer_options.val_scheduler_criterion)) elif isinstance(scheduler, AbsEpochStepScheduler): scheduler.step() if trainer_options.sharded_ddp: for optimizer in optimizers: if isinstance(optimizer, fairscale.optim.oss.OSS): optimizer.consolidate_state_dict() if not distributed_option.distributed or distributed_option.dist_rank == 0: # 3. Report the results logging.info(reporter.log_message()) if trainer_options.use_matplotlib: reporter.matplotlib_plot(output_dir / "images") if train_summary_writer is not None: reporter.tensorboard_add_scalar(train_summary_writer, key1="train") reporter.tensorboard_add_scalar(valid_summary_writer, key1="valid") if trainer_options.use_wandb: reporter.wandb_log() # 4. Save/Update the checkpoint torch.save( { "model": model.state_dict(), "reporter": reporter.state_dict(), "optimizers": [o.state_dict() for o in optimizers], "schedulers": [ s.state_dict() if s is not None else None for s in schedulers ], "scaler": scaler.state_dict() if scaler is not None else None, }, output_dir / "checkpoint.pth", ) # 5. Save and log the model and update the link to the best model torch.save(model.state_dict(), output_dir / f"{iepoch}epoch.pth") # Creates a sym link latest.pth -> {iepoch}epoch.pth p = output_dir / "latest.pth" if p.is_symlink() or p.exists(): p.unlink() p.symlink_to(f"{iepoch}epoch.pth") _improved = [] for _phase, k, _mode in trainer_options.best_model_criterion: # e.g. _phase, k, _mode = "train", "loss", "min" if reporter.has(_phase, k): best_epoch = reporter.get_best_epoch(_phase, k, _mode) # Creates sym links if it's the best result if best_epoch == iepoch: p = output_dir / f"{_phase}.{k}.best.pth" if p.is_symlink() or p.exists(): p.unlink() p.symlink_to(f"{iepoch}epoch.pth") _improved.append(f"{_phase}.{k}") if len(_improved) == 0: logging.info("There are no improvements in this epoch") else: logging.info("The best model has been updated: " + ", ".join(_improved)) log_model = (trainer_options.wandb_model_log_interval > 0 and iepoch % trainer_options.wandb_model_log_interval == 0) if log_model and trainer_options.use_wandb: import wandb logging.info("Logging Model on this epoch :::::") artifact = wandb.Artifact( name=f"model_{wandb.run.id}", type="model", metadata={"improved": _improved}, ) artifact.add_file(str(output_dir / f"{iepoch}epoch.pth")) aliases = [ f"epoch-{iepoch}", "best" if best_epoch == iepoch else "", ] wandb.log_artifact(artifact, aliases=aliases) # 6. Remove the model files excluding n-best epoch and latest epoch _removed = [] # Get the union set of the n-best among multiple criterion nbests = set().union(*[ set( reporter.sort_epochs(ph, k, m) [:max(keep_nbest_models)]) for ph, k, m in trainer_options.best_model_criterion if reporter.has(ph, k) ]) # Generated n-best averaged model if (trainer_options.nbest_averaging_interval > 0 and iepoch % trainer_options.nbest_averaging_interval == 0): average_nbest_models( reporter=reporter, output_dir=output_dir, best_model_criterion=trainer_options. best_model_criterion, nbest=keep_nbest_models, suffix=f"till{iepoch}epoch", ) for e in range(1, iepoch): p = output_dir / f"{e}epoch.pth" if p.exists() and e not in nbests: p.unlink() _removed.append(str(p)) if len(_removed) != 0: logging.info("The model files were removed: " + ", ".join(_removed)) # 7. If any updating haven't happened, stops the training if all_steps_are_invalid: logging.warning( f"The gradients at all steps are invalid in this epoch. " f"Something seems wrong. This training was stopped at {iepoch}epoch" ) break # 8. Check early stopping if trainer_options.patience is not None: if reporter.check_early_stopping( trainer_options.patience, *trainer_options.early_stopping_criterion): break else: logging.info( f"The training was finished at {trainer_options.max_epoch} epochs " ) # Generated n-best averaged model if not distributed_option.distributed or distributed_option.dist_rank == 0: average_nbest_models( reporter=reporter, output_dir=output_dir, best_model_criterion=trainer_options.best_model_criterion, nbest=keep_nbest_models, )
def run( cls, model: AbsESPnetModel, optimizers: Sequence[torch.optim.Optimizer], schedulers: Sequence[Optional[AbsScheduler]], train_iter_factory: AbsIterFactory, valid_iter_factory: AbsIterFactory, plot_attention_iter_factory: Optional[AbsIterFactory], reporter: Reporter, scaler: Optional[GradScaler], output_dir: Path, max_epoch: int, seed: int, patience: Optional[int], keep_nbest_models: int, early_stopping_criterion: Sequence[str], best_model_criterion: Sequence[Sequence[str]], val_scheduler_criterion: Sequence[str], trainer_options, distributed_option: DistributedOption, ) -> None: """Perform training. This method performs the main process of training.""" assert check_argument_types() # NOTE(kamo): Don't check the type more strictly as far trainer_options assert is_dataclass(trainer_options), type(trainer_options) start_epoch = reporter.get_epoch() + 1 if start_epoch == max_epoch + 1: logging.warning( f"The training has already reached at max_epoch: {start_epoch}" ) if distributed_option.distributed: dp_model = torch.nn.parallel.DistributedDataParallel( model, device_ids=( # Perform multi-Process with multi-GPUs [torch.cuda.current_device()] if distributed_option.ngpu == 1 # Perform single-Process with multi-GPUs else None), output_device=(torch.cuda.current_device() if distributed_option.ngpu == 1 else None), ) elif distributed_option.ngpu > 1: dp_model = torch.nn.parallel.DataParallel( model, device_ids=list(range(distributed_option.ngpu)), ) else: # NOTE(kamo): DataParallel also should work with ngpu=1, # but for debuggability it's better to keep this block. dp_model = model if not distributed_option.distributed or distributed_option.dist_rank == 0: summary_writer = SummaryWriter(str(output_dir / "tensorboard")) else: summary_writer = None start_time = time.perf_counter() for iepoch in range(start_epoch, max_epoch + 1): if iepoch != start_epoch: logging.info( "{}/{}epoch started. Estimated time to finish: {}".format( iepoch, max_epoch, humanfriendly.format_timespan( (time.perf_counter() - start_time) / (iepoch - start_epoch) * (max_epoch - iepoch + 1)), )) else: logging.info(f"{iepoch}/{max_epoch}epoch started") set_all_random_seed(seed + iepoch) reporter.set_epoch(iepoch) # 1. Train and validation for one-epoch with reporter.observe("train") as sub_reporter: all_steps_are_invalid = cls.train_one_epoch( model=dp_model, optimizers=optimizers, schedulers=schedulers, iterator=train_iter_factory.build_iter(iepoch), reporter=sub_reporter, scaler=scaler, summary_writer=summary_writer, options=trainer_options, ) with reporter.observe("valid") as sub_reporter: cls.validate_one_epoch( model=dp_model, iterator=valid_iter_factory.build_iter(iepoch), reporter=sub_reporter, options=trainer_options, ) if not distributed_option.distributed or distributed_option.dist_rank == 0: # att_plot doesn't support distributed if plot_attention_iter_factory is not None: with reporter.observe("att_plot") as sub_reporter: cls.plot_attention( model=model, output_dir=output_dir / "att_ws", summary_writer=summary_writer, iterator=plot_attention_iter_factory.build_iter( iepoch), reporter=sub_reporter, options=trainer_options, ) # 2. LR Scheduler step for scheduler in schedulers: if isinstance(scheduler, AbsValEpochStepScheduler): scheduler.step( reporter.get_value(*val_scheduler_criterion)) elif isinstance(scheduler, AbsEpochStepScheduler): scheduler.step() if not distributed_option.distributed or distributed_option.dist_rank == 0: # 3. Report the results logging.info(reporter.log_message()) reporter.matplotlib_plot(output_dir / "images") reporter.tensorboard_add_scalar(summary_writer) # 4. Save/Update the checkpoint torch.save( { "model": model.state_dict(), "reporter": reporter.state_dict(), "optimizers": [o.state_dict() for o in optimizers], "schedulers": [ s.state_dict() if s is not None else None for s in schedulers ], "scaler": scaler.state_dict() if scaler is not None else None, }, output_dir / "checkpoint.pth", ) # 5. Save the model and update the link to the best model torch.save(model.state_dict(), output_dir / f"{iepoch}epoch.pth") # Creates a sym link latest.pth -> {iepoch}epoch.pth p = output_dir / "latest.pth" if p.is_symlink() or p.exists(): p.unlink() p.symlink_to(f"{iepoch}epoch.pth") _improved = [] for _phase, k, _mode in best_model_criterion: # e.g. _phase, k, _mode = "train", "loss", "min" if reporter.has(_phase, k): best_epoch = reporter.get_best_epoch(_phase, k, _mode) # Creates sym links if it's the best result if best_epoch == iepoch: p = output_dir / f"{_phase}.{k}.best.pth" if p.is_symlink() or p.exists(): p.unlink() p.symlink_to(f"{iepoch}epoch.pth") _improved.append(f"{_phase}.{k}") if len(_improved) == 0: logging.info("There are no improvements in this epoch") else: logging.info("The best model has been updated: " + ", ".join(_improved)) # 6. Remove the model files excluding n-best epoch and latest epoch _removed = [] # Get the union set of the n-best among multiple criterion nbests = set().union(*[ set(reporter.sort_epochs(ph, k, m)[:keep_nbest_models]) for ph, k, m in best_model_criterion if reporter.has(ph, k) ]) for e in range(1, iepoch): p = output_dir / f"{e}epoch.pth" if p.exists() and e not in nbests: p.unlink() _removed.append(str(p)) if len(_removed) != 0: logging.info("The model files were removed: " + ", ".join(_removed)) # 7. If any updating haven't happened, stops the training if all_steps_are_invalid: logging.warning( f"The gradients at all steps are invalid in this epoch. " f"Something seems wrong. This training was stopped at {iepoch}epoch" ) break # 8. Check early stopping if patience is not None: if reporter.check_early_stopping(patience, *early_stopping_criterion): break else: logging.info(f"The training was finished at {max_epoch} epochs ")