def calculate_all_attentions(
    model: AbsESPnetModel, batch: Dict[str, torch.Tensor]
) -> Dict[str, List[torch.Tensor]]:
    """Derive the outputs from all attention layers.

    Args:
        model:
        batch: same as forward
    Returns:
        return_dict: A dict of lists of tensors.
        key_names x batch x (D1, D2, ...)

    """
    bs = len(next(iter(batch.values())))
    assert all(len(v) == bs for v in batch.values()), {
        k: v.shape for k, v in batch.items()
    }

    # 1. Register forward_hook fn to save the output from specific layers
    outputs = {}
    handles = {}
    for name, modu in model.named_modules():

        def hook(module, input, output, name=name):
            if isinstance(module, MultiHeadedAttention):
                # NOTE(kamo): MultiHeadedAttention doesn't return attention weight
                # attn: (B, Head, Tout, Tin)
                outputs[name] = module.attn.detach().cpu()
            elif isinstance(module, AttLoc2D):
                c, w = output
                # w: previously concatenated attentions
                # w: (B, nprev, Tin)
                att_w = w[:, -1].detach().cpu()
                outputs.setdefault(name, []).append(att_w)
            elif isinstance(module, (AttCov, AttCovLoc)):
                c, w = output
                assert isinstance(w, list), type(w)
                # w: list of previous attentions
                # w: nprev x (B, Tin)
                att_w = w[-1].detach().cpu()
                outputs.setdefault(name, []).append(att_w)
            elif isinstance(module, AttLocRec):
                # w: (B, Tin)
                c, (w, (att_h, att_c)) = output
                att_w = w.detach().cpu()
                outputs.setdefault(name, []).append(att_w)
            elif isinstance(
                module,
                (
                    AttMultiHeadDot,
                    AttMultiHeadAdd,
                    AttMultiHeadLoc,
                    AttMultiHeadMultiResLoc,
                ),
            ):
                c, w = output
                # w: nhead x (B, Tin)
                assert isinstance(w, list), type(w)
                att_w = [_w.detach().cpu() for _w in w]
                outputs.setdefault(name, []).append(att_w)
            elif isinstance(
                module,
                (
                    AttAdd,
                    AttDot,
                    AttForward,
                    AttForwardTA,
                    AttLoc,
                    NoAtt,
                ),
            ):
                c, w = output
                att_w = w.detach().cpu()
                outputs.setdefault(name, []).append(att_w)

        handle = modu.register_forward_hook(hook)
        handles[name] = handle

    # 2. Forward one sample at a time.
    # Batch-mode is not used, to keep the requirement small for each model.
    keys = []
    for k in batch:
        if not (k.endswith("_lengths") or k in ["utt_id"]):
            keys.append(k)

    return_dict = defaultdict(list)
    for ibatch in range(bs):
        # *: (B, L, ...) -> (1, L2, ...)
        _sample = {
            k: batch[k][ibatch, None, : batch[k + "_lengths"][ibatch]]
            if k + "_lengths" in batch
            else batch[k][ibatch, None]
            for k in keys
        }

        # *_lengths: (B,) -> (1,)
        _sample.update(
            {
                k + "_lengths": batch[k + "_lengths"][ibatch, None]
                for k in keys
                if k + "_lengths" in batch
            }
        )

        if "utt_id" in batch:
            _sample["utt_id"] = batch["utt_id"]

        model(**_sample)

        # Derive the attention results
        for name, output in outputs.items():
            if isinstance(output, list):
                if isinstance(output[0], list):
                    # output: nhead x (Tout, Tin)
                    output = torch.stack(
                        [
                            # Tout x (1, Tin) -> (Tout, Tin)
                            torch.cat([o[idx] for o in output], dim=0)
                            for idx in range(len(output[0]))
                        ],
                        dim=0,
                    )
                else:
                    # Tout x (1, Tin) -> (Tout, Tin)
                    output = torch.cat(output, dim=0)
            else:
                # output: (1, NHead, Tout, Tin) -> (NHead, Tout, Tin)
                output = output.squeeze(0)
            # output: (Tout, Tin) or (NHead, Tout, Tin)
            return_dict[name].append(output)
        outputs.clear()

    # 3. Remove all hooks
    for _, handle in handles.items():
        handle.remove()

    return dict(return_dict)
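# NOTE: The helper below is a minimal, illustrative sketch of how
# calculate_all_attentions() might be used to dump attention plots. It is NOT
# part of ESPnet; the function name, the "speech"/"text" batch keys, and the
# output file layout are hypothetical and only assume the per-sample tensor
# shapes documented above ((Tout, Tin) or (NHead, Tout, Tin)).
def _example_plot_attentions(
    model: AbsESPnetModel, batch: Dict[str, torch.Tensor], output_dir: str
) -> None:
    import os

    import matplotlib.pyplot as plt

    att_dict = calculate_all_attentions(model, batch)
    os.makedirs(output_dir, exist_ok=True)
    for name, att_list in att_dict.items():
        # att_list: one tensor per sample in the batch
        att = att_list[0]
        if att.dim() == 3:
            # (NHead, Tout, Tin): visualize the first head only
            att = att[0]
        plt.imshow(att.numpy(), aspect="auto")
        plt.savefig(os.path.join(output_dir, name.replace(".", "_") + ".png"))
        plt.close()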
def collect_stats(
    model: AbsESPnetModel,
    train_iter: DataLoader and Iterable[Tuple[List[str], Dict[str, torch.Tensor]]],
    valid_iter: DataLoader and Iterable[Tuple[List[str], Dict[str, torch.Tensor]]],
    output_dir: Path,
    ngpu: Optional[int],
    log_interval: Optional[int],
    write_collected_feats: bool,
) -> None:
    """Run in collect_stats mode.

    Derives the shape information from the data and gathers feature statistics.
    This method is used before executing train().

    """
    assert check_argument_types()

    npy_scp_writers = {}
    for itr, mode in zip([train_iter, valid_iter], ["train", "valid"]):
        if log_interval is None:
            try:
                log_interval = max(len(itr) // 20, 10)
            except TypeError:
                log_interval = 100

        sum_dict = defaultdict(lambda: 0)
        sq_dict = defaultdict(lambda: 0)
        count_dict = defaultdict(lambda: 0)

        with DatadirWriter(output_dir / mode) as datadir_writer:
            for iiter, (keys, batch) in enumerate(itr, 1):
                batch = to_device(batch, "cuda" if ngpu > 0 else "cpu")

                # 1. Write shape file
                for name in batch:
                    if name.endswith("_lengths"):
                        continue
                    for i, (key, data) in enumerate(zip(keys, batch[name])):
                        if f"{name}_lengths" in batch:
                            lg = int(batch[f"{name}_lengths"][i])
                            data = data[:lg]
                        datadir_writer[f"{name}_shape"][key] = ",".join(
                            map(str, data.shape)
                        )

                # 2. Extract feats
                if ngpu <= 1:
                    data = model.collect_feats(**batch)
                else:
                    # Note that data_parallel can parallelize only "forward()"
                    data = data_parallel(
                        ForwardAdaptor(model, "collect_feats"),
                        (),
                        range(ngpu),
                        module_kwargs=batch,
                    )

                # 3. Calculate sum and square sum
                for key, v in data.items():
                    for i, (uttid, seq) in enumerate(zip(keys, v.cpu().numpy())):
                        # Truncate the zero-padding region
                        if f"{key}_lengths" in data:
                            length = data[f"{key}_lengths"][i]
                            # seq: (Length, Dim, ...)
                            seq = seq[:length]
                        else:
                            # seq: (Dim, ...) -> (1, Dim, ...)
                            seq = seq[None]
                        # Accumulate the value, its square, and the count
                        sum_dict[key] += seq.sum(0)
                        sq_dict[key] += (seq ** 2).sum(0)
                        count_dict[key] += len(seq)

                        # 4. [Option] Write the derived features as npy files.
                        if write_collected_feats:
                            # Instantiate NpyScpWriter on the first iteration
                            if (key, mode) not in npy_scp_writers:
                                p = output_dir / mode / "collect_feats"
                                npy_scp_writers[(key, mode)] = NpyScpWriter(
                                    p / f"data_{key}", p / f"{key}.scp"
                                )
                            # Save the array as an npy file
                            npy_scp_writers[(key, mode)][uttid] = seq

                if iiter % log_interval == 0:
                    logging.info(f"Niter: {iiter}")

        for key in sum_dict:
            np.savez(
                output_dir / mode / f"{key}_stats.npz",
                count=count_dict[key],
                sum=sum_dict[key],
                sum_square=sq_dict[key],
            )

        # batch_keys and stats_keys are used by aggregate_stats_dirs.py
        with (output_dir / mode / "batch_keys").open("w", encoding="utf-8") as f:
            f.write(
                "\n".join(filter(lambda x: not x.endswith("_lengths"), batch)) + "\n"
            )
        with (output_dir / mode / "stats_keys").open("w", encoding="utf-8") as f:
            f.write("\n".join(sum_dict) + "\n")
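# NOTE: Illustrative sketch only (not part of ESPnet; aggregate_stats_dirs.py
# is the actual downstream tool). It shows how the "{key}_stats.npz" files
# written above can be reduced to a global mean/std for feature normalization:
# mean = sum / count and var = sum_square / count - mean**2. The helper name
# is hypothetical.
def _example_load_mean_std(stats_file: Path) -> Tuple[np.ndarray, np.ndarray]:
    stats = np.load(stats_file)
    count = stats["count"]
    mean = stats["sum"] / count
    # E[x^2] - E[x]^2, clipped to avoid tiny negative values from rounding
    var = np.maximum(stats["sum_square"] / count - mean ** 2, 1e-20)
    return mean, np.sqrt(var)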
def run(
    cls,
    model: AbsESPnetModel,
    optimizers: Sequence[torch.optim.Optimizer],
    schedulers: Sequence[Optional[AbsScheduler]],
    train_iter_factory: AbsIterFactory,
    valid_iter_factory: AbsIterFactory,
    plot_attention_iter_factory: Optional[AbsIterFactory],
    trainer_options,
    distributed_option: DistributedOption,
) -> None:
    """Perform training. This method runs the main training loop."""
    assert check_argument_types()
    # NOTE(kamo): Don't check the type of trainer_options more strictly than this
    assert is_dataclass(trainer_options), type(trainer_options)
    assert len(optimizers) == len(schedulers), (len(optimizers), len(schedulers))

    if isinstance(trainer_options.keep_nbest_models, int):
        keep_nbest_models = [trainer_options.keep_nbest_models]
    else:
        if len(trainer_options.keep_nbest_models) == 0:
            logging.warning("No keep_nbest_models is given. Changing to [1]")
            trainer_options.keep_nbest_models = [1]
        keep_nbest_models = trainer_options.keep_nbest_models

    output_dir = Path(trainer_options.output_dir)
    reporter = Reporter()
    if trainer_options.use_amp:
        if LooseVersion(torch.__version__) < LooseVersion("1.6.0"):
            raise RuntimeError(
                "torch>=1.6.0 is required for Automatic Mixed Precision"
            )
        if trainer_options.sharded_ddp:
            if fairscale is None:
                raise RuntimeError(
                    "fairscale is required. Do 'pip install fairscale'"
                )
            scaler = fairscale.optim.grad_scaler.ShardedGradScaler()
        else:
            scaler = GradScaler()
    else:
        scaler = None

    if trainer_options.resume and (output_dir / "checkpoint.pth").exists():
        cls.resume(
            checkpoint=output_dir / "checkpoint.pth",
            model=model,
            optimizers=optimizers,
            schedulers=schedulers,
            reporter=reporter,
            scaler=scaler,
            ngpu=trainer_options.ngpu,
        )

    start_epoch = reporter.get_epoch() + 1
    if start_epoch == trainer_options.max_epoch + 1:
        logging.warning(
            f"The training has already reached max_epoch: {start_epoch}"
        )

    if distributed_option.distributed:
        if trainer_options.sharded_ddp:
            dp_model = fairscale.nn.data_parallel.ShardedDataParallel(
                module=model,
                sharded_optimizer=optimizers,
            )
        else:
            dp_model = torch.nn.parallel.DistributedDataParallel(
                model,
                device_ids=(
                    # Multi-process with one GPU per process
                    [torch.cuda.current_device()]
                    if distributed_option.ngpu == 1
                    # Single process with multiple GPUs
                    else None
                ),
                output_device=(
                    torch.cuda.current_device()
                    if distributed_option.ngpu == 1
                    else None
                ),
                find_unused_parameters=trainer_options.unused_parameters,
            )
    elif distributed_option.ngpu > 1:
        dp_model = torch.nn.parallel.DataParallel(
            model,
            device_ids=list(range(distributed_option.ngpu)),
        )
    else:
        # NOTE(kamo): DataParallel should also work with ngpu=1,
        # but for debuggability it's better to keep this block.
        dp_model = model

    if trainer_options.use_tensorboard and (
        not distributed_option.distributed or distributed_option.dist_rank == 0
    ):
        from torch.utils.tensorboard import SummaryWriter

        train_summary_writer = SummaryWriter(
            str(output_dir / "tensorboard" / "train")
        )
        valid_summary_writer = SummaryWriter(
            str(output_dir / "tensorboard" / "valid")
        )
    else:
        train_summary_writer = None

    start_time = time.perf_counter()
    for iepoch in range(start_epoch, trainer_options.max_epoch + 1):
        if iepoch != start_epoch:
            logging.info(
                "{}/{}epoch started. Estimated time to finish: {}".format(
                    iepoch,
                    trainer_options.max_epoch,
                    humanfriendly.format_timespan(
                        (time.perf_counter() - start_time)
                        / (iepoch - start_epoch)
                        * (trainer_options.max_epoch - iepoch + 1)
                    ),
                )
            )
        else:
            logging.info(f"{iepoch}/{trainer_options.max_epoch}epoch started")
        set_all_random_seed(trainer_options.seed + iepoch)

        reporter.set_epoch(iepoch)
        # 1. Train and validate for one epoch
        with reporter.observe("train") as sub_reporter:
            all_steps_are_invalid = cls.train_one_epoch(
                model=dp_model,
                optimizers=optimizers,
                schedulers=schedulers,
                iterator=train_iter_factory.build_iter(iepoch),
                reporter=sub_reporter,
                scaler=scaler,
                summary_writer=train_summary_writer,
                options=trainer_options,
                distributed_option=distributed_option,
            )

        with reporter.observe("valid") as sub_reporter:
            cls.validate_one_epoch(
                model=dp_model,
                iterator=valid_iter_factory.build_iter(iepoch),
                reporter=sub_reporter,
                options=trainer_options,
                distributed_option=distributed_option,
            )

        if not distributed_option.distributed or distributed_option.dist_rank == 0:
            # att_plot doesn't support distributed mode
            if plot_attention_iter_factory is not None:
                with reporter.observe("att_plot") as sub_reporter:
                    cls.plot_attention(
                        model=model,
                        output_dir=output_dir / "att_ws",
                        summary_writer=train_summary_writer,
                        iterator=plot_attention_iter_factory.build_iter(iepoch),
                        reporter=sub_reporter,
                        options=trainer_options,
                    )

        # 2. LR Scheduler step
        for scheduler in schedulers:
            if isinstance(scheduler, AbsValEpochStepScheduler):
                scheduler.step(
                    reporter.get_value(*trainer_options.val_scheduler_criterion)
                )
            elif isinstance(scheduler, AbsEpochStepScheduler):
                scheduler.step()

        if trainer_options.sharded_ddp:
            for optimizer in optimizers:
                if isinstance(optimizer, fairscale.optim.oss.OSS):
                    optimizer.consolidate_state_dict()

        if not distributed_option.distributed or distributed_option.dist_rank == 0:
            # 3. Report the results
            logging.info(reporter.log_message())
            if trainer_options.use_matplotlib:
                reporter.matplotlib_plot(output_dir / "images")
            if train_summary_writer is not None:
                reporter.tensorboard_add_scalar(train_summary_writer, key1="train")
                reporter.tensorboard_add_scalar(valid_summary_writer, key1="valid")
            if trainer_options.use_wandb:
                reporter.wandb_log()

            # 4. Save/Update the checkpoint
            torch.save(
                {
                    "model": model.state_dict(),
                    "reporter": reporter.state_dict(),
                    "optimizers": [o.state_dict() for o in optimizers],
                    "schedulers": [
                        s.state_dict() if s is not None else None
                        for s in schedulers
                    ],
                    "scaler": scaler.state_dict() if scaler is not None else None,
                },
                output_dir / "checkpoint.pth",
            )

            # 5. Save and log the model and update the link to the best model
            torch.save(model.state_dict(), output_dir / f"{iepoch}epoch.pth")

            # Create a symlink: latest.pth -> {iepoch}epoch.pth
            p = output_dir / "latest.pth"
            if p.is_symlink() or p.exists():
                p.unlink()
            p.symlink_to(f"{iepoch}epoch.pth")

            _improved = []
            for _phase, k, _mode in trainer_options.best_model_criterion:
                # e.g. _phase, k, _mode = "train", "loss", "min"
                if reporter.has(_phase, k):
                    best_epoch = reporter.get_best_epoch(_phase, k, _mode)
                    # Create a symlink if it's the best result
                    if best_epoch == iepoch:
                        p = output_dir / f"{_phase}.{k}.best.pth"
                        if p.is_symlink() or p.exists():
                            p.unlink()
                        p.symlink_to(f"{iepoch}epoch.pth")
                        _improved.append(f"{_phase}.{k}")
            if len(_improved) == 0:
                logging.info("There are no improvements in this epoch")
            else:
                logging.info(
                    "The best model has been updated: " + ", ".join(_improved)
                )

            log_model = (
                trainer_options.wandb_model_log_interval > 0
                and iepoch % trainer_options.wandb_model_log_interval == 0
            )
            if log_model and trainer_options.use_wandb:
                import wandb

                logging.info("Logging the model of this epoch to wandb")
                artifact = wandb.Artifact(
                    name=f"model_{wandb.run.id}",
                    type="model",
                    metadata={"improved": _improved},
                )
                artifact.add_file(str(output_dir / f"{iepoch}epoch.pth"))
                aliases = [
                    f"epoch-{iepoch}",
                    "best" if best_epoch == iepoch else "",
                ]
                wandb.log_artifact(artifact, aliases=aliases)

            # 6. Remove the model files excluding the n-best and latest epochs
            _removed = []
            # Get the union set of the n-best among multiple criteria
            nbests = set().union(
                *[
                    set(reporter.sort_epochs(ph, k, m)[: max(keep_nbest_models)])
                    for ph, k, m in trainer_options.best_model_criterion
                    if reporter.has(ph, k)
                ]
            )

            # Generate the n-best averaged model
            if (
                trainer_options.nbest_averaging_interval > 0
                and iepoch % trainer_options.nbest_averaging_interval == 0
            ):
                average_nbest_models(
                    reporter=reporter,
                    output_dir=output_dir,
                    best_model_criterion=trainer_options.best_model_criterion,
                    nbest=keep_nbest_models,
                    suffix=f"till{iepoch}epoch",
                )

            for e in range(1, iepoch):
                p = output_dir / f"{e}epoch.pth"
                if p.exists() and e not in nbests:
                    p.unlink()
                    _removed.append(str(p))
            if len(_removed) != 0:
                logging.info(
                    "The model files were removed: " + ", ".join(_removed)
                )

        # 7. If no update has happened, stop the training
        if all_steps_are_invalid:
            logging.warning(
                f"The gradients at all steps are invalid in this epoch. "
                f"Something seems wrong. This training was stopped at {iepoch}epoch"
            )
            break

        # 8. Check early stopping
        if trainer_options.patience is not None:
            if reporter.check_early_stopping(
                trainer_options.patience,
                *trainer_options.early_stopping_criterion,
            ):
                break

    else:
        logging.info(
            f"The training was finished at {trainer_options.max_epoch} epochs"
        )

    # Generate the n-best averaged model
    if not distributed_option.distributed or distributed_option.dist_rank == 0:
        average_nbest_models(
            reporter=reporter,
            output_dir=output_dir,
            best_model_criterion=trainer_options.best_model_criterion,
            nbest=keep_nbest_models,
        )
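# NOTE: Illustrative sketch only (not part of ESPnet). Given the naming
# convention established above, inference code can load the symlinked best
# checkpoint, e.g. "valid.loss.best.pth", or any "{N}epoch.pth" file directly.
# The helper name and the default phase/criterion are hypothetical.
def _example_load_best_model(
    output_dir: Path,
    model: AbsESPnetModel,
    phase: str = "valid",
    criterion: str = "loss",
) -> None:
    # The files are plain state_dicts written by torch.save(model.state_dict(), ...)
    state_dict = torch.load(
        output_dir / f"{phase}.{criterion}.best.pth", map_location="cpu"
    )
    model.load_state_dict(state_dict)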
def run(
    cls,
    model: AbsESPnetModel,
    optimizers: Sequence[torch.optim.Optimizer],
    schedulers: Sequence[Optional[AbsScheduler]],
    train_iter_factory: AbsIterFactory,
    valid_iter_factory: AbsIterFactory,
    plot_attention_iter_factory: Optional[AbsIterFactory],
    reporter: Reporter,
    scaler: Optional[GradScaler],
    output_dir: Path,
    max_epoch: int,
    seed: int,
    patience: Optional[int],
    keep_nbest_models: int,
    early_stopping_criterion: Sequence[str],
    best_model_criterion: Sequence[Sequence[str]],
    val_scheduler_criterion: Sequence[str],
    trainer_options,
    distributed_option: DistributedOption,
) -> None:
    """Perform training. This method runs the main training loop."""
    assert check_argument_types()
    # NOTE(kamo): Don't check the type of trainer_options more strictly than this
    assert is_dataclass(trainer_options), type(trainer_options)

    start_epoch = reporter.get_epoch() + 1
    if start_epoch == max_epoch + 1:
        logging.warning(
            f"The training has already reached max_epoch: {start_epoch}"
        )

    if distributed_option.distributed:
        dp_model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=(
                # Multi-process with one GPU per process
                [torch.cuda.current_device()]
                if distributed_option.ngpu == 1
                # Single process with multiple GPUs
                else None
            ),
            output_device=(
                torch.cuda.current_device()
                if distributed_option.ngpu == 1
                else None
            ),
        )
    elif distributed_option.ngpu > 1:
        dp_model = torch.nn.parallel.DataParallel(
            model,
            device_ids=list(range(distributed_option.ngpu)),
        )
    else:
        # NOTE(kamo): DataParallel should also work with ngpu=1,
        # but for debuggability it's better to keep this block.
        dp_model = model

    if not distributed_option.distributed or distributed_option.dist_rank == 0:
        summary_writer = SummaryWriter(str(output_dir / "tensorboard"))
    else:
        summary_writer = None

    start_time = time.perf_counter()
    for iepoch in range(start_epoch, max_epoch + 1):
        if iepoch != start_epoch:
            logging.info(
                "{}/{}epoch started. Estimated time to finish: {}".format(
                    iepoch,
                    max_epoch,
                    humanfriendly.format_timespan(
                        (time.perf_counter() - start_time)
                        / (iepoch - start_epoch)
                        * (max_epoch - iepoch + 1)
                    ),
                )
            )
        else:
            logging.info(f"{iepoch}/{max_epoch}epoch started")
        set_all_random_seed(seed + iepoch)

        reporter.set_epoch(iepoch)
        # 1. Train and validate for one epoch
        with reporter.observe("train") as sub_reporter:
            all_steps_are_invalid = cls.train_one_epoch(
                model=dp_model,
                optimizers=optimizers,
                schedulers=schedulers,
                iterator=train_iter_factory.build_iter(iepoch),
                reporter=sub_reporter,
                scaler=scaler,
                summary_writer=summary_writer,
                options=trainer_options,
            )

        with reporter.observe("valid") as sub_reporter:
            cls.validate_one_epoch(
                model=dp_model,
                iterator=valid_iter_factory.build_iter(iepoch),
                reporter=sub_reporter,
                options=trainer_options,
            )

        if not distributed_option.distributed or distributed_option.dist_rank == 0:
            # att_plot doesn't support distributed mode
            if plot_attention_iter_factory is not None:
                with reporter.observe("att_plot") as sub_reporter:
                    cls.plot_attention(
                        model=model,
                        output_dir=output_dir / "att_ws",
                        summary_writer=summary_writer,
                        iterator=plot_attention_iter_factory.build_iter(iepoch),
                        reporter=sub_reporter,
                        options=trainer_options,
                    )

        # 2. LR Scheduler step
        for scheduler in schedulers:
            if isinstance(scheduler, AbsValEpochStepScheduler):
                scheduler.step(reporter.get_value(*val_scheduler_criterion))
            elif isinstance(scheduler, AbsEpochStepScheduler):
                scheduler.step()

        if not distributed_option.distributed or distributed_option.dist_rank == 0:
            # 3. Report the results
            logging.info(reporter.log_message())
            reporter.matplotlib_plot(output_dir / "images")
            reporter.tensorboard_add_scalar(summary_writer)

            # 4. Save/Update the checkpoint
            torch.save(
                {
                    "model": model.state_dict(),
                    "reporter": reporter.state_dict(),
                    "optimizers": [o.state_dict() for o in optimizers],
                    "schedulers": [
                        s.state_dict() if s is not None else None
                        for s in schedulers
                    ],
                    "scaler": scaler.state_dict() if scaler is not None else None,
                },
                output_dir / "checkpoint.pth",
            )

            # 5. Save the model and update the link to the best model
            torch.save(model.state_dict(), output_dir / f"{iepoch}epoch.pth")

            # Create a symlink: latest.pth -> {iepoch}epoch.pth
            p = output_dir / "latest.pth"
            if p.is_symlink() or p.exists():
                p.unlink()
            p.symlink_to(f"{iepoch}epoch.pth")

            _improved = []
            for _phase, k, _mode in best_model_criterion:
                # e.g. _phase, k, _mode = "train", "loss", "min"
                if reporter.has(_phase, k):
                    best_epoch = reporter.get_best_epoch(_phase, k, _mode)
                    # Create a symlink if it's the best result
                    if best_epoch == iepoch:
                        p = output_dir / f"{_phase}.{k}.best.pth"
                        if p.is_symlink() or p.exists():
                            p.unlink()
                        p.symlink_to(f"{iepoch}epoch.pth")
                        _improved.append(f"{_phase}.{k}")
            if len(_improved) == 0:
                logging.info("There are no improvements in this epoch")
            else:
                logging.info(
                    "The best model has been updated: " + ", ".join(_improved)
                )

            # 6. Remove the model files excluding the n-best and latest epochs
            _removed = []
            # Get the union set of the n-best among multiple criteria
            nbests = set().union(
                *[
                    set(reporter.sort_epochs(ph, k, m)[:keep_nbest_models])
                    for ph, k, m in best_model_criterion
                    if reporter.has(ph, k)
                ]
            )
            for e in range(1, iepoch):
                p = output_dir / f"{e}epoch.pth"
                if p.exists() and e not in nbests:
                    p.unlink()
                    _removed.append(str(p))
            if len(_removed) != 0:
                logging.info(
                    "The model files were removed: " + ", ".join(_removed)
                )

        # 7. If no update has happened, stop the training
        if all_steps_are_invalid:
            logging.warning(
                f"The gradients at all steps are invalid in this epoch. "
                f"Something seems wrong. This training was stopped at {iepoch}epoch"
            )
            break

        # 8. Check early stopping
        if patience is not None:
            if reporter.check_early_stopping(patience, *early_stopping_criterion):
                break

    else:
        logging.info(f"The training was finished at {max_epoch} epochs")
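# NOTE: Illustrative sketch only; this is NOT ESPnet's average_nbest_models.
# It averages the parameters of the n-best epochs selected by a single
# criterion, using the same reporter.sort_epochs() call and "{N}epoch.pth"
# naming used above. The helper name and the output filename are hypothetical.
def _example_average_nbest(
    reporter: Reporter,
    output_dir: Path,
    phase: str = "valid",
    criterion: str = "loss",
    mode: str = "min",
    nbest: int = 5,
) -> None:
    epochs = reporter.sort_epochs(phase, criterion, mode)[:nbest]
    avg = None
    for e in epochs:
        states = torch.load(output_dir / f"{e}epoch.pth", map_location="cpu")
        if avg is None:
            avg = states
        else:
            for key in avg:
                avg[key] = avg[key] + states[key]
    for key in avg:
        avg[key] = avg[key] / len(epochs)
    torch.save(avg, output_dir / f"{phase}.{criterion}.ave_{nbest}best.pth")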