def get_validation_objf(dataloader: torch.utils.data.DataLoader,
                        model: AcousticModel,
                        P: k2.Fsa,
                        device: torch.device,
                        graph_compiler: MmiTrainingGraphCompiler,
                        den_scale: float = 1.0):
    total_objf = 0.
    total_frames = 0.  # for display only
    total_all_frames = 0.  # all frames including those seqs that failed.

    model.eval()

    for batch_idx, batch in enumerate(dataloader):
        objf, frames, all_frames = get_objf(batch=batch,
                                            model=model,
                                            P=P,
                                            device=device,
                                            graph_compiler=graph_compiler,
                                            is_training=False,
                                            is_update=False,
                                            den_scale=den_scale)
        total_objf += objf
        total_frames += frames
        total_all_frames += all_frames

    return total_objf, total_frames, total_all_frames

def get_validation_loss(dataloader: torch.utils.data.DataLoader,
                        model: AcousticModel,
                        P: k2.Fsa,
                        device: torch.device,
                        graph_compiler: MmiMbrTrainingGraphCompiler):
    total_loss = 0.
    total_mmi_loss = 0.
    total_mbr_loss = 0.
    total_frames = 0.  # for display only
    total_all_frames = 0.  # all frames including those seqs that failed.

    model.eval()

    for batch_idx, batch in enumerate(dataloader):
        mmi_loss, mbr_loss, frames, all_frames = get_loss(
            batch=batch,
            model=model,
            P=P,
            device=device,
            graph_compiler=graph_compiler,
            is_training=False)
        cur_loss = mmi_loss + mbr_loss
        total_loss += cur_loss
        total_mmi_loss += mmi_loss
        total_mbr_loss += mbr_loss
        total_frames += frames
        total_all_frames += all_frames

    return total_loss, total_mmi_loss, total_mbr_loss, total_frames, \
        total_all_frames

def get_validation_objf(
        dataloader: torch.utils.data.DataLoader,
        model: AcousticModel,
        ali_model: Optional[AcousticModel],
        P: k2.Fsa,
        device: torch.device,
        graph_compiler: MmiTrainingGraphCompiler,
        scaler: GradScaler,
        den_scale: float = 1.0,
):
    total_objf = 0.
    total_frames = 0.  # for display only
    total_all_frames = 0.  # all frames including those seqs that failed.

    model.eval()

    from torchaudio.datasets.utils import bg_iterator

    for batch_idx, batch in enumerate(bg_iterator(dataloader, 2)):
        objf, frames, all_frames = get_objf(batch=batch,
                                            model=model,
                                            ali_model=ali_model,
                                            P=P,
                                            device=device,
                                            graph_compiler=graph_compiler,
                                            is_training=False,
                                            is_update=False,
                                            den_scale=den_scale,
                                            scaler=scaler)
        total_objf += objf
        total_frames += frames
        total_all_frames += all_frames

    return total_objf, total_frames, total_all_frames

def load_checkpoint(filename: Pathlike,
                    model: AcousticModel) -> Dict[str, Any]:
    logging.info('load checkpoint from {}'.format(filename))

    checkpoint = torch.load(filename, map_location='cpu')

    keys = [
        'state_dict', 'epoch', 'learning_rate', 'objf', 'valid_objf',
        'num_features', 'num_classes', 'subsampling_factor',
        'global_batch_idx_train'
    ]
    missing_keys = set(keys) - set(checkpoint.keys())
    if missing_keys:
        raise ValueError(f'Missing keys in checkpoint: {missing_keys}')

    if not list(model.state_dict().keys())[0].startswith('module.') \
            and list(checkpoint['state_dict'])[0].startswith('module.'):
        # the checkpoint was saved by DDP
        logging.info('load checkpoint from DDP')
        dst_state_dict = model.state_dict()
        src_state_dict = checkpoint['state_dict']
        for key in dst_state_dict.keys():
            src_key = '{}.{}'.format('module', key)
            dst_state_dict[key] = src_state_dict.pop(src_key)
        assert len(src_state_dict) == 0
        model.load_state_dict(dst_state_dict)
    else:
        model.load_state_dict(checkpoint['state_dict'])

    model.num_features = checkpoint['num_features']
    model.num_classes = checkpoint['num_classes']
    model.subsampling_factor = checkpoint['subsampling_factor']

    return checkpoint

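def _example_resume_model(model: AcousticModel) -> int:
    """Usage sketch for load_checkpoint (illustrative, not from the
    original source).

    Assumes `model` is an already-constructed AcousticModel whose
    architecture matches the checkpoint, and that the path
    'exp/epoch-9.pt' is a hypothetical example; returns the epoch
    to resume training from.
    """
    checkpoint = load_checkpoint('exp/epoch-9.pt', model)
    return checkpoint['epoch'] + 1
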
def train_one_epoch(dataloader: torch.utils.data.DataLoader,
                    valid_dataloader: torch.utils.data.DataLoader,
                    model: AcousticModel,
                    device: torch.device,
                    graph_compiler: TrainingGraphCompiler,
                    optimizer: torch.optim.Optimizer,
                    current_epoch: int,
                    num_epochs: int):
    total_objf, total_frames, total_all_frames = 0., 0., 0.
    time_waiting_for_batch = 0
    prev_timestamp = datetime.now()

    model.train()
    for batch_idx, batch in enumerate(dataloader):
        timestamp = datetime.now()
        time_waiting_for_batch += (timestamp - prev_timestamp).total_seconds()
        curr_batch_objf, curr_batch_frames, curr_batch_all_frames = \
            get_objf(batch, model, device, graph_compiler, True, optimizer)

        total_objf += curr_batch_objf
        total_frames += curr_batch_frames
        total_all_frames += curr_batch_all_frames

        if batch_idx % 10 == 0:
            logging.info(
                'batch {}, epoch {}/{} '
                'global average objf: {:.6f} over {} '
                'frames ({:.1f}% kept), current batch average objf: {:.6f} over {} frames ({:.1f}% kept) '
                'avg time waiting for batch {:.3f}s'.format(
                    batch_idx, current_epoch, num_epochs,
                    total_objf / total_frames, total_frames,
                    100.0 * total_frames / total_all_frames,
                    curr_batch_objf / (curr_batch_frames + 0.001),
                    curr_batch_frames,
                    100.0 * curr_batch_frames / curr_batch_all_frames,
                    time_waiting_for_batch / max(1, batch_idx)))

        # if batch_idx >= 10:
        #     print("Exiting early to get profile info")
        #     sys.exit(0)

        if batch_idx > 0 and batch_idx % 200 == 0:
            total_valid_objf, total_valid_frames, total_valid_all_frames = \
                get_validation_objf(dataloader=valid_dataloader,
                                    model=model,
                                    device=device,
                                    graph_compiler=graph_compiler)
            model.train()
            logging.info(
                'Validation average objf: {:.6f} over {} frames ({:.1f}% kept)'
                .format(total_valid_objf / total_valid_frames,
                        total_valid_frames,
                        100.0 * total_valid_frames / total_valid_all_frames))
        prev_timestamp = datetime.now()
    return total_objf

def load_checkpoint(
        filename: Pathlike,
        model: AcousticModel,
        optimizer: Optional[object] = None,
        scheduler: Optional[object] = None,
        scaler: Optional[GradScaler] = None,
) -> Dict[str, Any]:
    logging.info('load checkpoint from {}'.format(filename))

    checkpoint = torch.load(filename, map_location='cpu')

    keys = [
        'state_dict', 'optimizer', 'scheduler', 'epoch', 'learning_rate',
        'objf', 'valid_objf', 'num_features', 'num_classes',
        'subsampling_factor', 'global_batch_idx_train'
    ]
    missing_keys = set(keys) - set(checkpoint.keys())
    if missing_keys:
        raise ValueError(f'Missing keys in checkpoint: {missing_keys}')

    if isinstance(model, DistributedDataParallel):
        model = model.module

    if not list(model.state_dict().keys())[0].startswith('module.') \
            and list(checkpoint['state_dict'])[0].startswith('module.'):
        # the checkpoint was saved by DDP
        logging.info('load checkpoint from DDP')
        dst_state_dict = model.state_dict()
        src_state_dict = checkpoint['state_dict']
        for key in dst_state_dict.keys():
            src_key = '{}.{}'.format('module', key)
            dst_state_dict[key] = src_state_dict.pop(src_key)
        assert len(src_state_dict) == 0
        model.load_state_dict(dst_state_dict, strict=False)
    else:
        model.load_state_dict(checkpoint['state_dict'], strict=False)
    # Note: we use strict=False above so that the current code
    # can load models trained with P_scores.

    model.num_features = checkpoint['num_features']
    model.num_classes = checkpoint['num_classes']
    model.subsampling_factor = checkpoint['subsampling_factor']

    if optimizer is not None:
        optimizer.load_state_dict(checkpoint['optimizer'])

    if scheduler is not None:
        scheduler.load_state_dict(checkpoint['scheduler'])

    if scaler is not None:
        # 'grad_scaler' is only present in checkpoints saved with AMP
        # enabled; it is intentionally not in the required `keys` above.
        scaler.load_state_dict(checkpoint['grad_scaler'])

    return checkpoint

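def _example_full_resume(model: AcousticModel,
                         optimizer: torch.optim.Optimizer,
                         scheduler: object,
                         scaler: GradScaler) -> Dict[str, Any]:
    """Usage sketch (illustrative, not from the original source).

    Restores the full training state. The path is a hypothetical
    example; the checkpoint must contain 'optimizer', 'scheduler'
    and (when AMP was used) 'grad_scaler' entries.
    """
    return load_checkpoint('exp/epoch-9.pt',
                           model,
                           optimizer=optimizer,
                           scheduler=scheduler,
                           scaler=scaler)
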
def average_checkpoint(filenames: List[Pathlike],
                       model: AcousticModel) -> Dict[str, Any]:
    logging.info('average over checkpoints {}'.format(filenames))

    avg_model = None

    # sum
    for filename in filenames:
        checkpoint = torch.load(filename, map_location='cpu')
        checkpoint_model = checkpoint['state_dict']
        if avg_model is None:
            avg_model = checkpoint_model
        else:
            for k in avg_model.keys():
                avg_model[k] += checkpoint_model[k]
    # average
    for k in avg_model.keys():
        if avg_model[k] is not None:
            if avg_model[k].is_floating_point():
                avg_model[k] /= len(filenames)
            else:
                avg_model[k] //= len(filenames)

    checkpoint['state_dict'] = avg_model

    keys = [
        'state_dict', 'optimizer', 'scheduler', 'epoch', 'learning_rate',
        'objf', 'valid_objf', 'num_features', 'num_classes',
        'subsampling_factor', 'global_batch_idx_train'
    ]
    missing_keys = set(keys) - set(checkpoint.keys())
    if missing_keys:
        raise ValueError(f'Missing keys in checkpoint: {missing_keys}')

    if not list(model.state_dict().keys())[0].startswith('module.') \
            and list(checkpoint['state_dict'])[0].startswith('module.'):
        # the checkpoint was saved by DDP
        logging.info('load checkpoint from DDP')
        dst_state_dict = model.state_dict()
        src_state_dict = checkpoint['state_dict']
        for key in dst_state_dict.keys():
            src_key = '{}.{}'.format('module', key)
            dst_state_dict[key] = src_state_dict.pop(src_key)
        assert len(src_state_dict) == 0
        model.load_state_dict(dst_state_dict, strict=False)
    else:
        model.load_state_dict(checkpoint['state_dict'], strict=False)

    model.num_features = checkpoint['num_features']
    model.num_classes = checkpoint['num_classes']
    model.subsampling_factor = checkpoint['subsampling_factor']

    return checkpoint

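def _example_average_last_epochs(model: AcousticModel) -> Dict[str, Any]:
    """Usage sketch (illustrative, not from the original source).

    Averages the weights of the last 5 epochs before decoding; the
    'exp/epoch-{n}.pt' directory layout is an assumption. Note that
    average_checkpoint also loads the averaged weights into `model`.
    """
    filenames = [f'exp/epoch-{n}.pt' for n in range(5, 10)]
    return average_checkpoint(filenames, model)
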
def get_validation_objf(
        dataloader: torch.utils.data.DataLoader,
        model: AcousticModel,
        device: torch.device,
):
    total_objf = 0.0
    total_frames = 0.0  # for display only

    model.eval()

    for batch_idx, batch in enumerate(dataloader):
        objf, frames = get_objf(batch, model, device, False)
        total_objf += objf
        total_frames += frames

    return total_objf, total_frames

def save_checkpoint(filename: Pathlike,
                    model: AcousticModel,
                    epoch: int,
                    learning_rate: float,
                    objf: float,
                    valid_objf: float,
                    global_batch_idx_train: int,
                    local_rank: int = 0) -> None:
    if local_rank is not None and local_rank != 0:
        return
    logging.info(
        f'Save checkpoint to {filename}: epoch={epoch}, '
        f'learning_rate={learning_rate}, objf={objf}, valid_objf={valid_objf}')
    checkpoint = {
        'state_dict': model.state_dict(),
        'num_features': model.num_features,
        'num_classes': model.num_classes,
        'subsampling_factor': model.subsampling_factor,
        'epoch': epoch,
        'learning_rate': learning_rate,
        'objf': objf,
        'valid_objf': valid_objf,
        'global_batch_idx_train': global_batch_idx_train,
    }
    torch.save(checkpoint, filename)

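# A typical call site for save_checkpoint (illustrative sketch; the variable
# names below are assumptions, not from the original source). Under DDP the
# `local_rank` guard above makes the call a no-op on every rank except 0:
#
#   save_checkpoint(f'exp/epoch-{epoch}.pt', model, epoch, curr_learning_rate,
#                   objf=train_objf, valid_objf=valid_objf,
#                   global_batch_idx_train=global_batch_idx_train,
#                   local_rank=local_rank)
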
def get_objf(batch: Dict,
             model: AcousticModel,
             P: k2.Fsa,
             device: torch.device,
             graph_compiler: MmiTrainingGraphCompiler,
             is_training: bool,
             den_scale: float = 1.0,
             tb_writer: Optional[SummaryWriter] = None,
             global_batch_idx_train: Optional[int] = None,
             optimizer: Optional[torch.optim.Optimizer] = None):
    feature = batch['inputs']
    # at entry, feature is [N, T, C]
    feature = feature.permute(0, 2, 1)  # now feature is [N, C, T]
    assert feature.ndim == 3
    feature = feature.to(device)

    supervisions = batch['supervisions']
    supervision_segments, texts = encode_supervisions(
        supervisions, model.subsampling_factor)

    loss_fn = LFMMILoss(graph_compiler=graph_compiler,
                        P=P,
                        den_scale=den_scale)

    grad_context = nullcontext if is_training else torch.no_grad

    with grad_context():
        nnet_output = model(feature)
        # nnet_output is [N, C, T]
        nnet_output = nnet_output.permute(0, 2, 1)  # now it is [N, T, C]
        mmi_loss, tot_frames, all_frames = loss_fn(nnet_output, texts,
                                                   supervision_segments)

    if is_training:

        def maybe_log_gradients(tag: str):
            if (tb_writer is not None and global_batch_idx_train is not None
                    and global_batch_idx_train % 200 == 0):
                tb_writer.add_scalars(tag,
                                      measure_gradient_norms(model,
                                                             norm='l1'),
                                      global_step=global_batch_idx_train)

        optimizer.zero_grad()
        (-mmi_loss).backward()
        maybe_log_gradients('train/grad_norms')
        clip_grad_value_(model.parameters(), 5.0)
        maybe_log_gradients('train/clipped_grad_norms')

        if tb_writer is not None and global_batch_idx_train % 200 == 0:
            # Once in a while we perform a more costly diagnostic
            # to check the relative parameter change per minibatch.
            deltas = optim_step_and_measure_param_change(model, optimizer)
            tb_writer.add_scalars('train/relative_param_change_per_minibatch',
                                  deltas,
                                  global_step=global_batch_idx_train)
        else:
            optimizer.step()

    ans = (-mmi_loss.detach().cpu().item(), tot_frames.cpu().item(),
           all_frames.cpu().item())
    return ans

def get_validation_objf(dataloader: torch.utils.data.DataLoader,
                        model: AcousticModel,
                        device: torch.device,
                        graph_compiler: CtcTrainingGraphCompiler):
    total_objf = 0.
    total_frames = 0.  # for display only
    total_all_frames = 0.  # all frames including those seqs that failed.

    model.eval()

    for batch_idx, batch in enumerate(dataloader):
        objf, frames, all_frames = get_objf(batch, model, device,
                                            graph_compiler, False)
        total_objf += objf
        total_frames += frames
        total_all_frames += all_frames

    return total_objf, total_frames, total_all_frames

def get_objf(
        batch: Dict,
        model: AcousticModel,
        device: torch.device,
        training: bool,
        optimizer: Optional[torch.optim.Optimizer] = None,
        class_weights: Optional[torch.Tensor] = None,
):
    feature = batch["inputs"]  # (N, T, C)
    supervisions = batch["supervisions"]["is_voice"].unsqueeze(
        -1).long()  # (N, T, 1)
    feature = feature.to(device)
    supervisions = supervisions.to(device)
    if class_weights is not None:
        class_weights = class_weights.to(device)

    # at entry, feature is [N, T, C]
    feature = feature.permute(0, 2, 1)  # now feature is [N, C, T]

    if training:
        nnet_output = model(feature)
    else:
        with torch.no_grad():
            nnet_output = model(feature)

    # nnet_output is [N, C, T]
    nnet_output = nnet_output.permute(0, 2, 1)  # now it is [N, T, C]

    # Compute cross-entropy loss
    xent_loss = torch.nn.CrossEntropyLoss(reduction="sum",
                                          weight=class_weights)
    tot_score = xent_loss(nnet_output.contiguous().view(-1, 2),
                          supervisions.contiguous().view(-1))

    if training:
        optimizer.zero_grad()
        tot_score.backward()
        clip_grad_value_(model.parameters(), 5.0)
        optimizer.step()

    ans = (
        tot_score.detach().cpu().item(),  # total objective function value
        supervisions.numel(),  # number of frames
    )
    return ans

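# Sketch of how class_weights might be used with the VAD objective above
# (illustrative values, not from the original source). A two-class weight
# tensor can counteract the speech/non-speech imbalance typical of VAD data:
#
#   class_weights = torch.tensor([0.3, 0.7])  # [non-speech, speech]
#   objf, num_frames = get_objf(batch, model, device, training=True,
#                               optimizer=optimizer,
#                               class_weights=class_weights)
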
def save_checkpoint(filename: Pathlike,
                    model: AcousticModel,
                    epoch: int,
                    learning_rate: float,
                    objf: float,
                    local_rank: int = 0) -> None:
    if local_rank is not None and local_rank != 0:
        return
    logging.info('Save checkpoint to {filename}: epoch={epoch}, '
                 'learning_rate={learning_rate}, objf={objf}'.format(
                     filename=filename,
                     epoch=epoch,
                     learning_rate=learning_rate,
                     objf=objf))
    checkpoint = {
        'state_dict': model.state_dict(),
        'num_features': model.num_features,
        'num_classes': model.num_classes,
        'subsampling_factor': model.subsampling_factor,
        'epoch': epoch,
        'learning_rate': learning_rate,
        'objf': objf
    }
    torch.save(checkpoint, filename)

def get_objf(batch: Dict,
             model: AcousticModel,
             P: k2.Fsa,
             device: torch.device,
             graph_compiler: MmiTrainingGraphCompiler,
             is_training: bool,
             den_scale: float = 1.0,
             optimizer: Optional[torch.optim.Optimizer] = None):
    feature = batch['features']
    supervisions = batch['supervisions']
    supervision_segments = torch.stack(
        (supervisions['sequence_idx'],
         torch.floor_divide(supervisions['start_frame'],
                            model.subsampling_factor),
         torch.floor_divide(supervisions['num_frames'],
                            model.subsampling_factor)), 1).to(torch.int32)
    indices = torch.argsort(supervision_segments[:, 2], descending=True)
    supervision_segments = supervision_segments[indices]

    texts = supervisions['text']
    texts = [texts[idx] for idx in indices]

    assert feature.ndim == 3
    # print(supervision_segments[:, 1] + supervision_segments[:, 2])

    feature = feature.to(device)
    # at entry, feature is [N, T, C]
    feature = feature.permute(0, 2, 1)  # now feature is [N, C, T]
    if is_training:
        nnet_output = model(feature)
    else:
        with torch.no_grad():
            nnet_output = model(feature)

    # nnet_output is [N, C, T]
    nnet_output = nnet_output.permute(0, 2, 1)  # now it is [N, T, C]

    if is_training:
        num, den = graph_compiler.compile(texts, P)
    else:
        with torch.no_grad():
            num, den = graph_compiler.compile(texts, P)

    assert num.requires_grad == is_training
    assert den.requires_grad is False
    num = num.to(device)
    den = den.to(device)

    # nnet_output2 = nnet_output.clone()
    # blank_bias = -7.0
    # nnet_output2[:,:,0] += blank_bias

    dense_fsa_vec = k2.DenseFsaVec(nnet_output, supervision_segments)
    assert nnet_output.device == device

    num = k2.intersect_dense(num, dense_fsa_vec, 10.0)
    den = k2.intersect_dense(den, dense_fsa_vec, 10.0)

    num_tot_scores = num.get_tot_scores(log_semiring=True,
                                        use_double_scores=True)
    den_tot_scores = den.get_tot_scores(log_semiring=True,
                                        use_double_scores=True)
    tot_scores = num_tot_scores - den_scale * den_tot_scores

    (tot_score, tot_frames,
     all_frames) = get_tot_objf_and_num_frames(tot_scores,
                                               supervision_segments[:, 2])

    if is_training:
        optimizer.zero_grad()
        (-tot_score).backward()
        clip_grad_value_(model.parameters(), 5.0)
        optimizer.step()

    ans = (-tot_score.detach().cpu().item(), tot_frames.cpu().item(),
           all_frames.cpu().item())
    return ans

def get_objf(batch: Dict,
             model: AcousticModel,
             ali_model: Optional[AcousticModel],
             P: k2.Fsa,
             device: torch.device,
             graph_compiler: MmiTrainingGraphCompiler,
             is_training: bool,
             is_update: bool,
             accum_grad: int = 1,
             den_scale: float = 1.0,
             att_rate: float = 0.0,
             tb_writer: Optional[SummaryWriter] = None,
             global_batch_idx_train: Optional[int] = None,
             optimizer: Optional[torch.optim.Optimizer] = None,
             scaler: Optional[GradScaler] = None):
    feature = batch['inputs']
    # at entry, feature is [N, T, C]
    feature = feature.permute(0, 2, 1)  # now feature is [N, C, T]
    assert feature.ndim == 3
    feature = feature.to(device)

    supervisions = batch['supervisions']
    supervision_segments, texts = encode_supervisions(supervisions)

    loss_fn = LFMMILoss(graph_compiler=graph_compiler,
                        P=P,
                        den_scale=den_scale)

    grad_context = nullcontext if is_training else torch.no_grad

    with autocast(enabled=scaler.is_enabled()), grad_context():
        nnet_output, encoder_memory, memory_mask = model(feature,
                                                         supervisions)
        if att_rate != 0.0:
            att_loss = model.module.decoder_forward(encoder_memory,
                                                    memory_mask,
                                                    supervisions,
                                                    graph_compiler)

        if (ali_model is not None and global_batch_idx_train is not None and
                global_batch_idx_train // accum_grad < 4000):
            with torch.no_grad():
                ali_model_output = ali_model(feature)
            # subsampling is done slightly differently, may be small length
            # differences.
            min_len = min(ali_model_output.shape[2], nnet_output.shape[2])
            # scale less than one so it will be encouraged
            # to mimic ali_model's output
            ali_model_scale = 500.0 / (global_batch_idx_train // accum_grad +
                                       500)
            nnet_output = nnet_output.clone()  # or log-softmax backprop will fail.
            nnet_output[:, :, :min_len] += \
                ali_model_scale * ali_model_output[:, :, :min_len]

        # nnet_output is [N, C, T]
        nnet_output = nnet_output.permute(0, 2, 1)  # now it is [N, T, C]
        mmi_loss, tot_frames, all_frames = loss_fn(nnet_output, texts,
                                                   supervision_segments)

    if is_training:

        def maybe_log_gradients(tag: str):
            if (tb_writer is not None and global_batch_idx_train is not None
                    and global_batch_idx_train % 200 == 0):
                tb_writer.add_scalars(tag,
                                      measure_gradient_norms(model,
                                                             norm='l1'),
                                      global_step=global_batch_idx_train)

        if att_rate != 0.0:
            loss = (-(1.0 - att_rate) * mmi_loss +
                    att_rate * att_loss) / (len(texts) * accum_grad)
        else:
            loss = (-mmi_loss) / (len(texts) * accum_grad)
        scaler.scale(loss).backward()

        if is_update:
            maybe_log_gradients('train/grad_norms')
            scaler.unscale_(optimizer)
            clip_grad_value_(model.parameters(), 5.0)
            maybe_log_gradients('train/clipped_grad_norms')
            if tb_writer is not None and \
                    (global_batch_idx_train // accum_grad) % 200 == 0:
                # Once in a while we perform a more costly diagnostic
                # to check the relative parameter change per minibatch.
                deltas = optim_step_and_measure_param_change(
                    model, optimizer, scaler)
                tb_writer.add_scalars(
                    'train/relative_param_change_per_minibatch',
                    deltas,
                    global_step=global_batch_idx_train)
            else:
                scaler.step(optimizer)
            optimizer.zero_grad()
            scaler.update()

    ans = (-mmi_loss.detach().cpu().item(), tot_frames.cpu().item(),
           all_frames.cpu().item())
    return ans

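# Sketch of the AMP plumbing that get_objf above expects (illustrative; the
# `use_amp` flag is an assumption, not from the original source). A single
# GradScaler is created per run and threaded through every call; constructing
# it with enabled=False recovers plain fp32 training without touching this
# function:
#
#   scaler = GradScaler(enabled=use_amp)
#   objf, frames, all_frames = get_objf(batch=batch, model=model,
#                                       ali_model=None, P=P, device=device,
#                                       graph_compiler=graph_compiler,
#                                       is_training=True, is_update=True,
#                                       optimizer=optimizer, scaler=scaler)
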
def train_one_epoch(dataloader: torch.utils.data.DataLoader,
                    valid_dataloader: torch.utils.data.DataLoader,
                    model: AcousticModel,
                    device: torch.device,
                    graph_compiler: CtcTrainingGraphCompiler,
                    optimizer: torch.optim.Optimizer,
                    accum_grad: int,
                    att_rate: float,
                    current_epoch: int,
                    tb_writer: SummaryWriter,
                    num_epochs: int,
                    global_batch_idx_train: int):
    """One epoch training and validation.

    Args:
        dataloader: Training dataloader.
        valid_dataloader: Validation dataloader.
        model: Acoustic model to be trained.
        device: Training device, torch.device("cpu") or
            torch.device("cuda", device_id).
        graph_compiler: CTC training graph compiler.
        optimizer: Training optimizer.
        accum_grad: Number of gradient accumulation steps.
        att_rate: Attention loss rate; the final loss is
            att_rate * att_loss + (1 - att_rate) * other_loss.
        current_epoch: Current training epoch, for logging only.
        tb_writer: Tensorboard SummaryWriter.
        num_epochs: Total number of training epochs, for logging only.
        global_batch_idx_train: Global training batch index before this
            epoch, for logging only.

    Returns:
        A tuple of 3 scalars:
        (total_objf / total_frames, valid_average_objf, global_batch_idx_train)

        - `total_objf / total_frames` is the average training loss
        - `valid_average_objf` is the average validation loss
        - `global_batch_idx_train` is the global training batch index after
          this epoch
    """
    total_objf, total_frames, total_all_frames = 0., 0., 0.
    valid_average_objf = float('inf')
    time_waiting_for_batch = 0
    forward_count = 0
    prev_timestamp = datetime.now()

    model.train()
    for batch_idx, batch in enumerate(dataloader):
        forward_count += 1
        if forward_count == accum_grad:
            is_update = True
            forward_count = 0
        else:
            is_update = False

        global_batch_idx_train += 1
        timestamp = datetime.now()
        time_waiting_for_batch += (timestamp - prev_timestamp).total_seconds()

        curr_batch_objf, curr_batch_frames, curr_batch_all_frames = get_objf(
            batch=batch,
            model=model,
            device=device,
            graph_compiler=graph_compiler,
            is_training=True,
            is_update=is_update,
            accum_grad=accum_grad,
            att_rate=att_rate,
            tb_writer=tb_writer,
            global_batch_idx_train=global_batch_idx_train,
            optimizer=optimizer)

        total_objf += curr_batch_objf
        total_frames += curr_batch_frames
        total_all_frames += curr_batch_all_frames

        if batch_idx % 10 == 0:
            logging.info(
                'batch {}, epoch {}/{} '
                'global average objf: {:.6f} over {} '
                'frames ({:.1f}% kept), current batch average objf: {:.6f} over {} frames ({:.1f}% kept) '
                'avg time waiting for batch {:.3f}s'.format(
                    batch_idx, current_epoch, num_epochs,
                    total_objf / total_frames, total_frames,
                    100.0 * total_frames / total_all_frames,
                    curr_batch_objf / (curr_batch_frames + 0.001),
                    curr_batch_frames,
                    100.0 * curr_batch_frames / curr_batch_all_frames,
                    time_waiting_for_batch / max(1, batch_idx)))
            if tb_writer is not None:
                tb_writer.add_scalar('train/global_average_objf',
                                     total_objf / total_frames,
                                     global_batch_idx_train)
                tb_writer.add_scalar(
                    'train/current_batch_average_objf',
                    curr_batch_objf / (curr_batch_frames + 0.001),
                    global_batch_idx_train)

        # if batch_idx >= 10:
        #     print("Exiting early to get profile info")
        #     sys.exit(0)

        if batch_idx > 0 and batch_idx % 200 == 0:
            total_valid_objf, total_valid_frames, total_valid_all_frames = \
                get_validation_objf(dataloader=valid_dataloader,
                                    model=model,
                                    device=device,
                                    graph_compiler=graph_compiler)
            valid_average_objf = total_valid_objf / total_valid_frames
            model.train()
            logging.info(
                'Validation average objf: {:.6f} over {} frames ({:.1f}% kept)'
                .format(valid_average_objf, total_valid_frames,
                        100.0 * total_valid_frames / total_valid_all_frames))
            if tb_writer is not None:
                tb_writer.add_scalar('train/global_valid_average_objf',
                                     valid_average_objf,
                                     global_batch_idx_train)
                model.write_tensorboard_diagnostics(
                    tb_writer, global_step=global_batch_idx_train)
        prev_timestamp = datetime.now()
    return total_objf / total_frames, valid_average_objf, \
        global_batch_idx_train

def get_loss(batch: Dict,
             model: AcousticModel,
             P: k2.Fsa,
             device: torch.device,
             graph_compiler: MmiMbrTrainingGraphCompiler,
             is_training: bool,
             den_scale: float = 1.0,
             optimizer: Optional[torch.optim.Optimizer] = None):
    assert P.device == device
    feature = batch['features']
    supervisions = batch['supervisions']
    supervision_segments = torch.stack(
        (supervisions['sequence_idx'],
         torch.floor_divide(supervisions['start_frame'],
                            model.subsampling_factor),
         torch.floor_divide(supervisions['num_frames'],
                            model.subsampling_factor)), 1).to(torch.int32)
    indices = torch.argsort(supervision_segments[:, 2], descending=True)
    supervision_segments = supervision_segments[indices]

    texts = supervisions['text']
    texts = [texts[idx] for idx in indices]

    assert feature.ndim == 3
    # print(supervision_segments[:, 1] + supervision_segments[:, 2])

    feature = feature.to(device)
    # at entry, feature is [N, T, C]
    feature = feature.permute(0, 2, 1)  # now feature is [N, C, T]
    if is_training:
        nnet_output = model(feature)
    else:
        with torch.no_grad():
            nnet_output = model(feature)

    # nnet_output is [N, C, T]
    nnet_output = nnet_output.permute(0, 2, 1)  # now it is [N, T, C]

    if is_training:
        num_graph, den_graph, decoding_graph = graph_compiler.compile(
            texts, P)
    else:
        with torch.no_grad():
            num_graph, den_graph, decoding_graph = graph_compiler.compile(
                texts, P)

    assert num_graph.requires_grad == is_training
    assert den_graph.requires_grad is False
    assert decoding_graph.requires_grad is False
    assert len(
        decoding_graph.shape) == 2 or decoding_graph.shape == (1, None, None)

    num_graph = num_graph.to(device)
    den_graph = den_graph.to(device)
    decoding_graph = decoding_graph.to(device)

    dense_fsa_vec = k2.DenseFsaVec(nnet_output, supervision_segments)
    assert nnet_output.device == device

    num_lats = k2.intersect_dense(num_graph,
                                  dense_fsa_vec,
                                  10.0,
                                  seqframe_idx_name='seqframe_idx')

    mbr_lats = k2.intersect_dense_pruned(decoding_graph,
                                         dense_fsa_vec,
                                         20.0,
                                         7.0,
                                         30,
                                         10000,
                                         seqframe_idx_name='seqframe_idx')

    if True:
        # WARNING: the else branch is not working at present (the total loss
        # is not stable)
        den_lats = k2.intersect_dense(den_graph, dense_fsa_vec, 10.0)
    else:
        # in this case, we can remove den_graph
        den_lats = mbr_lats

    num_tot_scores = num_lats.get_tot_scores(log_semiring=True,
                                             use_double_scores=True)
    den_tot_scores = den_lats.get_tot_scores(log_semiring=True,
                                             use_double_scores=True)

    if id(den_lats) == id(mbr_lats):
        # Some entries in den_tot_scores may be -inf.
        # The corresponding sequences are discarded/ignored.
        finite_indexes = torch.isfinite(den_tot_scores)
        den_tot_scores = den_tot_scores[finite_indexes]
        num_tot_scores = num_tot_scores[finite_indexes]
    else:
        finite_indexes = None

    tot_scores = num_tot_scores - den_scale * den_tot_scores
    (tot_score, tot_frames,
     all_frames) = get_tot_objf_and_num_frames(tot_scores,
                                               supervision_segments[:, 2],
                                               finite_indexes)

    num_rows = dense_fsa_vec.scores.shape[0]
    num_cols = dense_fsa_vec.scores.shape[1] - 1
    mbr_num_sparse = k2.create_sparse(rows=num_lats.seqframe_idx,
                                      cols=num_lats.phones,
                                      values=num_lats.get_arc_post(
                                          True, True).exp(),
                                      size=(num_rows, num_cols),
                                      min_col_index=0)

    mbr_den_sparse = k2.create_sparse(rows=mbr_lats.seqframe_idx,
                                      cols=mbr_lats.phones,
                                      values=mbr_lats.get_arc_post(
                                          True, True).exp(),
                                      size=(num_rows, num_cols),
                                      min_col_index=0)

    # NOTE: Due to limited support of PyTorch's autograd for sparse tensors,
    # we cannot use (mbr_num_sparse - mbr_den_sparse) here
    #
    # The following works only for torch >= 1.7.0
    mbr_loss = torch.sparse.sum(
        k2.sparse.abs((mbr_num_sparse + (-mbr_den_sparse)).coalesce()))

    mmi_loss = -tot_score

    total_loss = mmi_loss + mbr_loss

    if is_training:
        optimizer.zero_grad()
        total_loss.backward()
        clip_grad_value_(model.parameters(), 5.0)
        optimizer.step()

    ans = (
        mmi_loss.detach().cpu().item(),
        mbr_loss.detach().cpu().item(),
        tot_frames.cpu().item(),
        all_frames.cpu().item(),
    )
    return ans

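# Note on the MBR term in get_loss above (explanatory comment, not from the
# original source): writing the numerator and denominator phone posteriors as
# sparse (seqframe, phone) matrices A and B, the expression
# torch.sparse.sum(k2.sparse.abs((A + (-B)).coalesce())) is the L1 distance
# sum_{t,p} |A[t,p] - B[t,p]|, i.e. it measures how far the per-frame phone
# occupation probabilities of the numerator lattice are from those of the
# decoding (denominator) lattice.
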
def train_one_epoch(dataloader: torch.utils.data.DataLoader,
                    valid_dataloader: torch.utils.data.DataLoader,
                    model: AcousticModel,
                    P: k2.Fsa,
                    device: torch.device,
                    graph_compiler: MmiTrainingGraphCompiler,
                    optimizer: torch.optim.Optimizer,
                    current_epoch: int,
                    tb_writer: SummaryWriter,
                    num_epochs: int,
                    global_batch_idx_train: int,
                    global_batch_idx_valid: int):
    total_objf, total_frames, total_all_frames = 0., 0., 0.
    time_waiting_for_batch = 0
    prev_timestamp = datetime.now()

    model.train()
    ragged_shape = P.arcs.shape().to(device)

    for batch_idx, batch in enumerate(dataloader):
        global_batch_idx_train += 1
        timestamp = datetime.now()
        time_waiting_for_batch += (timestamp - prev_timestamp).total_seconds()

        P.set_scores_stochastic_(model.P_scores)
        assert P.is_cpu
        assert P.requires_grad is True

        curr_batch_objf, curr_batch_frames, curr_batch_all_frames = \
            get_objf(batch, model, P, device, graph_compiler, True, optimizer)

        total_objf += curr_batch_objf
        total_frames += curr_batch_frames
        total_all_frames += curr_batch_all_frames

        if batch_idx % 10 == 0:
            logging.info(
                'batch {}, epoch {}/{} '
                'global average objf: {:.6f} over {} '
                'frames ({:.1f}% kept), current batch average objf: {:.6f} over {} frames ({:.1f}% kept) '
                'avg time waiting for batch {:.3f}s'.format(
                    batch_idx, current_epoch, num_epochs,
                    total_objf / total_frames, total_frames,
                    100.0 * total_frames / total_all_frames,
                    curr_batch_objf / (curr_batch_frames + 0.001),
                    curr_batch_frames,
                    100.0 * curr_batch_frames / curr_batch_all_frames,
                    time_waiting_for_batch / max(1, batch_idx)))
            tb_writer.add_scalar('train/global_average_objf',
                                 total_objf / total_frames,
                                 global_batch_idx_train)
            tb_writer.add_scalar('train/current_batch_average_objf',
                                 curr_batch_objf / (curr_batch_frames + 0.001),
                                 global_batch_idx_train)

        # if batch_idx >= 10:
        #     print("Exiting early to get profile info")
        #     sys.exit(0)

        if batch_idx > 0 and batch_idx % 200 == 0:
            total_valid_objf, total_valid_frames, total_valid_all_frames = \
                get_validation_objf(dataloader=valid_dataloader,
                                    model=model,
                                    P=P,
                                    device=device,
                                    graph_compiler=graph_compiler)
            global_batch_idx_valid += 1
            model.train()
            logging.info(
                'Validation average objf: {:.6f} over {} frames ({:.1f}% kept)'
                .format(total_valid_objf / total_valid_frames,
                        total_valid_frames,
                        100.0 * total_valid_frames / total_valid_all_frames))
            tb_writer.add_scalar('train/global_valid_average_objf',
                                 total_valid_objf / total_valid_frames,
                                 global_batch_idx_valid)
        prev_timestamp = datetime.now()
    return total_objf / total_frames

def train_one_epoch(dataloader: torch.utils.data.DataLoader,
                    valid_dataloader: torch.utils.data.DataLoader,
                    model: AcousticModel,
                    ali_model: Optional[AcousticModel],
                    P: k2.Fsa,
                    device: torch.device,
                    graph_compiler: MmiTrainingGraphCompiler,
                    use_pruned_intersect: bool,
                    optimizer: torch.optim.Optimizer,
                    accum_grad: int,
                    den_scale: float,
                    att_rate: float,
                    current_epoch: int,
                    tb_writer: SummaryWriter,
                    num_epochs: int,
                    global_batch_idx_train: int,
                    world_size: int,
                    scaler: GradScaler):
    """One epoch training and validation.

    Args:
        dataloader: Training dataloader.
        valid_dataloader: Validation dataloader.
        model: Acoustic model to be trained.
        ali_model: Optional alignment model whose output is mixed into the
            main model's output early in training.
        P: An FSA representing the bigram phone LM.
        device: Training device, torch.device("cpu") or
            torch.device("cuda", device_id).
        graph_compiler: MMI training graph compiler.
        use_pruned_intersect: Whether to use pruned intersection when
            computing the lattices.
        optimizer: Training optimizer.
        accum_grad: Number of gradient accumulation steps.
        den_scale: Denominator scale in the MMI loss.
        att_rate: Attention loss rate; the final loss is
            att_rate * att_loss + (1 - att_rate) * other_loss.
        current_epoch: Current training epoch, for logging only.
        tb_writer: Tensorboard SummaryWriter.
        num_epochs: Total number of training epochs, for logging only.
        global_batch_idx_train: Global training batch index before this
            epoch, for logging only.
        world_size: Number of distributed processes; validation statistics
            are all-reduced when it is larger than 1.
        scaler: GradScaler for mixed-precision training.

    Returns:
        A tuple of 3 scalars:
        (total_objf / total_frames, valid_average_objf, global_batch_idx_train)

        - `total_objf / total_frames` is the average training loss
        - `valid_average_objf` is the average validation loss
        - `global_batch_idx_train` is the global training batch index after
          this epoch
    """
    total_objf, total_frames, total_all_frames = 0., 0., 0.
    valid_average_objf = float('inf')
    time_waiting_for_batch = 0
    forward_count = 0
    prev_timestamp = datetime.now()

    model.train()
    for batch_idx, batch in enumerate(dataloader):
        forward_count += 1
        if forward_count == accum_grad:
            is_update = True
            forward_count = 0
        else:
            is_update = False

        global_batch_idx_train += 1
        timestamp = datetime.now()
        time_waiting_for_batch += (timestamp - prev_timestamp).total_seconds()

        if forward_count == 1 or accum_grad == 1:
            P.set_scores_stochastic_(model.module.P_scores)
            assert P.requires_grad is True

        curr_batch_objf, curr_batch_frames, curr_batch_all_frames = get_objf(
            batch=batch,
            model=model,
            ali_model=ali_model,
            P=P,
            device=device,
            graph_compiler=graph_compiler,
            use_pruned_intersect=use_pruned_intersect,
            is_training=True,
            is_update=is_update,
            accum_grad=accum_grad,
            den_scale=den_scale,
            att_rate=att_rate,
            tb_writer=tb_writer,
            global_batch_idx_train=global_batch_idx_train,
            optimizer=optimizer,
            scaler=scaler)

        total_objf += curr_batch_objf
        total_frames += curr_batch_frames
        total_all_frames += curr_batch_all_frames

        if batch_idx % 10 == 0:
            logging.info(
                'batch {}, epoch {}/{} '
                'global average objf: {:.6f} over {} '
                'frames ({:.1f}% kept), current batch average objf: {:.6f} over {} frames ({:.1f}% kept) '
                'avg time waiting for batch {:.3f}s'.format(
                    batch_idx, current_epoch, num_epochs,
                    total_objf / total_frames, total_frames,
                    100.0 * total_frames / total_all_frames,
                    curr_batch_objf / (curr_batch_frames + 0.001),
                    curr_batch_frames,
                    100.0 * curr_batch_frames / curr_batch_all_frames,
                    time_waiting_for_batch / max(1, batch_idx)))
            if tb_writer is not None:
                tb_writer.add_scalar('train/global_average_objf',
                                     total_objf / total_frames,
                                     global_batch_idx_train)
                tb_writer.add_scalar(
                    'train/current_batch_average_objf',
                    curr_batch_objf / (curr_batch_frames + 0.001),
                    global_batch_idx_train)

        # if batch_idx >= 10:
        #     print("Exiting early to get profile info")
        #     sys.exit(0)

        if batch_idx > 0 and batch_idx % 200 == 0:
            total_valid_objf, total_valid_frames, total_valid_all_frames = \
                get_validation_objf(dataloader=valid_dataloader,
                                    model=model,
                                    ali_model=ali_model,
                                    P=P,
                                    device=device,
                                    graph_compiler=graph_compiler,
                                    use_pruned_intersect=use_pruned_intersect,
                                    scaler=scaler)
            if world_size > 1:
                s = torch.tensor([
                    total_valid_objf, total_valid_frames,
                    total_valid_all_frames
                ]).to(device)
                dist.all_reduce(s, op=dist.ReduceOp.SUM)
                total_valid_objf, total_valid_frames, \
                    total_valid_all_frames = s.cpu().tolist()
            valid_average_objf = total_valid_objf / total_valid_frames
            model.train()
            logging.info(
                'Validation average objf: {:.6f} over {} frames ({:.1f}% kept)'
                .format(valid_average_objf, total_valid_frames,
                        100.0 * total_valid_frames / total_valid_all_frames))
            if tb_writer is not None:
                tb_writer.add_scalar('train/global_valid_average_objf',
                                     valid_average_objf,
                                     global_batch_idx_train)
                model.module.write_tensorboard_diagnostics(
                    tb_writer, global_step=global_batch_idx_train)
        prev_timestamp = datetime.now()
    return total_objf / total_frames, valid_average_objf, \
        global_batch_idx_train

def get_objf(batch: Dict,
             model: AcousticModel,
             device: torch.device,
             graph_compiler: MmiTrainingGraphCompiler,
             is_training: bool,
             den_scale: float = 1.0,
             tb_writer: Optional[SummaryWriter] = None,
             global_batch_idx_train: Optional[int] = None,
             optimizer: Optional[torch.optim.Optimizer] = None):
    feature = batch['inputs']
    # at entry, feature is [N, T, C]
    feature = feature.permute(0, 2, 1)  # now feature is [N, C, T]
    assert feature.ndim == 3
    feature = feature.to(device)

    # The model may be wrapped in DDP, in which case its attributes live
    # on model.module.
    try:
        subsampling_factor = model.subsampling_factor
    except AttributeError:
        subsampling_factor = model.module.subsampling_factor

    supervisions = batch['supervisions']
    supervision_segments, texts = encode_supervisions(supervisions,
                                                      subsampling_factor)

    loss_fn = LFMMILoss(graph_compiler=graph_compiler, den_scale=den_scale)

    grad_context = nullcontext if is_training else torch.no_grad

    with grad_context():
        nnet_output = model(feature)
        # nnet_output is [N, C, T]
        nnet_output = nnet_output.permute(0, 2, 1)  # now it is [N, T, C]
        mmi_loss, tot_frames, all_frames = loss_fn(nnet_output, texts,
                                                   supervision_segments)

    if is_training:

        def maybe_log_gradients(tag: str):
            if (tb_writer is not None and global_batch_idx_train is not None
                    and global_batch_idx_train % 200 == 0):
                tb_writer.add_scalars(tag,
                                      measure_gradient_norms(model,
                                                             norm='l1'),
                                      global_step=global_batch_idx_train)

        optimizer.zero_grad()
        (-mmi_loss).backward()
        # Report parameters that did not receive a gradient.
        for name, param in model.named_parameters():
            if param.grad is None:
                print(name)
        maybe_log_gradients('train/grad_norms')
        # clip_grad_value_(model.parameters(), 5.0)
        clip_grad_norm_(model.parameters(), max_norm=5.0, norm_type=2.0)
        maybe_log_gradients('train/clipped_grad_norms')

        if tb_writer is not None and global_batch_idx_train % 200 == 0:
            # Once in a while we perform a more costly diagnostic
            # to check the relative parameter change per minibatch.
            deltas = optim_step_and_measure_param_change(model, optimizer)
            tb_writer.add_scalars('train/relative_param_change_per_minibatch',
                                  deltas,
                                  global_step=global_batch_idx_train)
        else:
            optimizer.step()

    ans = (-mmi_loss.detach().cpu().item(), tot_frames.cpu().item(),
           all_frames.cpu().item())
    return ans

def train_one_epoch(dataloader: torch.utils.data.DataLoader,
                    valid_dataloader: torch.utils.data.DataLoader,
                    model: AcousticModel,
                    device: torch.device,
                    graph_compiler: MmiTrainingGraphCompiler,
                    optimizer: torch.optim.Optimizer,
                    current_epoch: int,
                    tb_writer: SummaryWriter,
                    num_epochs: int,
                    global_batch_idx_train: int):
    total_objf, total_frames, total_all_frames = 0., 0., 0.
    valid_average_objf = float('inf')
    time_waiting_for_batch = 0
    prev_timestamp = datetime.now()

    model.train()
    for batch_idx, batch in enumerate(dataloader):
        global_batch_idx_train += 1
        timestamp = datetime.now()
        time_waiting_for_batch += (timestamp - prev_timestamp).total_seconds()

        curr_batch_objf, curr_batch_frames, curr_batch_all_frames = get_objf(
            batch=batch,
            model=model,
            device=device,
            graph_compiler=graph_compiler,
            is_training=True,
            tb_writer=tb_writer,
            global_batch_idx_train=global_batch_idx_train,
            optimizer=optimizer)

        total_objf += curr_batch_objf
        total_frames += curr_batch_frames
        total_all_frames += curr_batch_all_frames

        if batch_idx % 10 == 0:
            logging.info(
                'batch {}, epoch {}/{} '
                'global average objf: {:.6f} over {} '
                'frames ({:.1f}% kept), current batch average objf: {:.6f} over {} frames ({:.1f}% kept) '
                'avg time waiting for batch {:.3f}s'.format(
                    batch_idx, current_epoch, num_epochs,
                    total_objf / total_frames, total_frames,
                    100.0 * total_frames / total_all_frames,
                    curr_batch_objf / (curr_batch_frames + 0.001),
                    curr_batch_frames,
                    100.0 * curr_batch_frames / curr_batch_all_frames,
                    time_waiting_for_batch / max(1, batch_idx)))
            tb_writer.add_scalar('train/global_average_objf',
                                 total_objf / total_frames,
                                 global_batch_idx_train)
            tb_writer.add_scalar('train/current_batch_average_objf',
                                 curr_batch_objf / (curr_batch_frames + 0.001),
                                 global_batch_idx_train)

        # if batch_idx >= 10:
        #     print("Exiting early to get profile info")
        #     sys.exit(0)

        if batch_idx > 0 and batch_idx % 200 == 0:
            total_valid_objf, total_valid_frames, total_valid_all_frames = \
                get_validation_objf(dataloader=valid_dataloader,
                                    model=model,
                                    device=device,
                                    graph_compiler=graph_compiler)
            valid_average_objf = total_valid_objf / total_valid_frames
            model.train()
            logging.info(
                'Validation average objf: {:.6f} over {} frames ({:.1f}% kept)'
                .format(valid_average_objf, total_valid_frames,
                        100.0 * total_valid_frames / total_valid_all_frames))
            tb_writer.add_scalar('train/global_valid_average_objf',
                                 valid_average_objf,
                                 global_batch_idx_train)
            model.write_tensorboard_diagnostics(
                tb_writer, global_step=global_batch_idx_train)
        prev_timestamp = datetime.now()
    return total_objf / total_frames, valid_average_objf, \
        global_batch_idx_train

def train_one_epoch(dataloader: torch.utils.data.DataLoader,
                    valid_dataloader: torch.utils.data.DataLoader,
                    model: AcousticModel,
                    P: k2.Fsa,
                    device: torch.device,
                    graph_compiler: MmiMbrTrainingGraphCompiler,
                    optimizer: torch.optim.Optimizer,
                    current_epoch: int,
                    tb_writer: SummaryWriter,
                    num_epochs: int,
                    global_batch_idx_train: int):
    total_loss, total_mmi_loss, total_mbr_loss, total_frames, \
        total_all_frames = 0., 0., 0., 0., 0.
    valid_average_loss = float('inf')
    time_waiting_for_batch = 0
    prev_timestamp = datetime.now()

    model.train()
    for batch_idx, batch in enumerate(dataloader):
        global_batch_idx_train += 1
        timestamp = datetime.now()
        time_waiting_for_batch += (timestamp - prev_timestamp).total_seconds()

        P.set_scores_stochastic_(model.P_scores)
        assert P.requires_grad is True

        curr_batch_mmi_loss, curr_batch_mbr_loss, curr_batch_frames, \
            curr_batch_all_frames = get_loss(
                batch=batch,
                model=model,
                P=P,
                device=device,
                graph_compiler=graph_compiler,
                is_training=True,
                optimizer=optimizer)

        total_mmi_loss += curr_batch_mmi_loss
        total_mbr_loss += curr_batch_mbr_loss
        curr_batch_loss = curr_batch_mmi_loss + curr_batch_mbr_loss
        total_loss += curr_batch_loss
        total_frames += curr_batch_frames
        total_all_frames += curr_batch_all_frames

        if batch_idx % 10 == 0:
            logging.info('batch {}, epoch {}/{} '
                         'global average loss: {:.6f}, '
                         'global average mmi loss: {:.6f}, '
                         'global average mbr loss: {:.6f} over {} '
                         'frames ({:.1f}% kept), '
                         'current batch average loss: {:.6f}, '
                         'current batch average mmi loss: {:.6f}, '
                         'current batch average mbr loss: {:.6f} '
                         'over {} frames ({:.1f}% kept) '
                         'avg time waiting for batch {:.3f}s'.format(
                             batch_idx, current_epoch, num_epochs,
                             total_loss / total_frames,
                             total_mmi_loss / total_frames,
                             total_mbr_loss / total_frames, total_frames,
                             100.0 * total_frames / total_all_frames,
                             curr_batch_loss / (curr_batch_frames + 0.001),
                             curr_batch_mmi_loss /
                             (curr_batch_frames + 0.001),
                             curr_batch_mbr_loss /
                             (curr_batch_frames + 0.001),
                             curr_batch_frames,
                             100.0 * curr_batch_frames /
                             curr_batch_all_frames,
                             time_waiting_for_batch / max(1, batch_idx)))
            tb_writer.add_scalar('train/global_average_loss',
                                 total_loss / total_frames,
                                 global_batch_idx_train)
            tb_writer.add_scalar('train/global_average_mmi_loss',
                                 total_mmi_loss / total_frames,
                                 global_batch_idx_train)
            tb_writer.add_scalar('train/global_average_mbr_loss',
                                 total_mbr_loss / total_frames,
                                 global_batch_idx_train)
            tb_writer.add_scalar('train/current_batch_average_loss',
                                 curr_batch_loss / (curr_batch_frames + 0.001),
                                 global_batch_idx_train)
            tb_writer.add_scalar(
                'train/current_batch_average_mmi_loss',
                curr_batch_mmi_loss / (curr_batch_frames + 0.001),
                global_batch_idx_train)
            tb_writer.add_scalar(
                'train/current_batch_average_mbr_loss',
                curr_batch_mbr_loss / (curr_batch_frames + 0.001),
                global_batch_idx_train)

        # if batch_idx >= 10:
        #     print("Exiting early to get profile info")
        #     sys.exit(0)

        if batch_idx > 0 and batch_idx % 3000 == 0:
            total_valid_loss, total_valid_mmi_loss, total_valid_mbr_loss, \
                total_valid_frames, total_valid_all_frames = \
                get_validation_loss(dataloader=valid_dataloader,
                                    model=model,
                                    P=P,
                                    device=device,
                                    graph_compiler=graph_compiler)
            valid_average_loss = total_valid_loss / total_valid_frames
            model.train()
            logging.info('Validation average loss: {:.6f}, '
                         'Validation average mmi loss: {:.6f}, '
                         'Validation average mbr loss: {:.6f} '
                         'over {} frames ({:.1f}% kept)'.format(
                             total_valid_loss / total_valid_frames,
                             total_valid_mmi_loss / total_valid_frames,
                             total_valid_mbr_loss / total_valid_frames,
                             total_valid_frames,
                             100.0 * total_valid_frames /
                             total_valid_all_frames))
            tb_writer.add_scalar('train/global_valid_average_loss',
                                 total_valid_loss / total_valid_frames,
                                 global_batch_idx_train)
            tb_writer.add_scalar('train/global_valid_average_mmi_loss',
                                 total_valid_mmi_loss / total_valid_frames,
                                 global_batch_idx_train)
            tb_writer.add_scalar('train/global_valid_average_mbr_loss',
                                 total_valid_mbr_loss / total_valid_frames,
                                 global_batch_idx_train)
        prev_timestamp = datetime.now()
    return total_loss / total_frames, valid_average_loss, \
        global_batch_idx_train

def train_one_epoch(
    dataloader: torch.utils.data.DataLoader,
    valid_dataloader: torch.utils.data.DataLoader,
    model: AcousticModel,
    device: torch.device,
    graph_compiler: CtcTrainingGraphCompiler,
    optimizer: torch.optim.Optimizer,
    current_epoch: int,
    tb_writer: SummaryWriter,
    num_epochs: int,
    global_batch_idx_train: int,
):
    total_objf, total_frames, total_all_frames = 0.0, 0.0, 0.0
    valid_average_objf = float("inf")
    time_waiting_for_batch = 0
    prev_timestamp = datetime.now()

    model.train()
    for batch_idx, batch in enumerate(dataloader):
        global_batch_idx_train += 1
        timestamp = datetime.now()
        time_waiting_for_batch += (timestamp - prev_timestamp).total_seconds()
        curr_batch_objf, curr_batch_frames, curr_batch_all_frames = get_objf(
            batch, model, device, graph_compiler, True, optimizer)

        total_objf += curr_batch_objf
        total_frames += curr_batch_frames
        total_all_frames += curr_batch_all_frames

        if batch_idx % 10 == 0:
            logging.info(
                "batch {}, epoch {}/{} "
                "global average objf: {:.6f} over {} "
                "frames ({:.1f}% kept), current batch average objf: {:.6f} over {} frames ({:.1f}% kept) "
                "avg time waiting for batch {:.3f}s".format(
                    batch_idx,
                    current_epoch,
                    num_epochs,
                    total_objf / total_frames,
                    total_frames,
                    100.0 * total_frames / total_all_frames,
                    curr_batch_objf / (curr_batch_frames + 0.001),
                    curr_batch_frames,
                    100.0 * curr_batch_frames / curr_batch_all_frames,
                    time_waiting_for_batch / max(1, batch_idx),
                ))
            tb_writer.add_scalar(
                "train/global_average_objf",
                total_objf / total_frames,
                global_batch_idx_train,
            )
            tb_writer.add_scalar(
                "train/current_batch_average_objf",
                curr_batch_objf / (curr_batch_frames + 0.001),
                global_batch_idx_train,
            )

        # if batch_idx >= 10:
        #     print("Exiting early to get profile info")
        #     sys.exit(0)

        if batch_idx > 0 and batch_idx % 200 == 0:
            (
                total_valid_objf,
                total_valid_frames,
                total_valid_all_frames,
            ) = get_validation_objf(
                dataloader=valid_dataloader,
                model=model,
                device=device,
                graph_compiler=graph_compiler,
            )
            valid_average_objf = total_valid_objf / total_valid_frames
            model.train()
            logging.info(
                "Validation average objf: {:.6f} over {} frames ({:.1f}% kept)"
                .format(
                    valid_average_objf,
                    total_valid_frames,
                    100.0 * total_valid_frames / total_valid_all_frames,
                ))
            tb_writer.add_scalar(
                "train/global_valid_average_objf",
                valid_average_objf,
                global_batch_idx_train,
            )
        prev_timestamp = datetime.now()
    return total_objf / total_frames, valid_average_objf, \
        global_batch_idx_train

def get_objf(batch: Dict,
             model: AcousticModel,
             device: torch.device,
             graph_compiler: CtcTrainingGraphCompiler,
             is_training: bool,
             is_update: bool,
             accum_grad: int = 1,
             att_rate: float = 0.0,
             optimizer: Optional[torch.optim.Optimizer] = None):
    feature = batch['features']
    supervisions = batch['supervisions']
    supervision_segments = torch.stack(
        (supervisions['sequence_idx'],
         (((supervisions['start_frame'] - 1) // 2 - 1) // 2),
         (((supervisions['num_frames'] - 1) // 2 - 1) // 2)),
        1).to(torch.int32)
    supervision_segments = torch.clamp(supervision_segments, min=0)
    indices = torch.argsort(supervision_segments[:, 2], descending=True)
    supervision_segments = supervision_segments[indices]

    texts = supervisions['text']
    texts = [texts[idx] for idx in indices]

    assert feature.ndim == 3
    # print(supervision_segments[:, 1] + supervision_segments[:, 2])

    feature = feature.to(device)
    # at entry, feature is [N, T, C]
    feature = feature.permute(0, 2, 1)  # now feature is [N, C, T]
    if is_training:
        nnet_output, encoder_memory, memory_mask = model(
            feature, supervision_segments)
        if att_rate != 0.0:
            att_loss = model.decoder_forward(encoder_memory, memory_mask,
                                             supervisions, graph_compiler)
    else:
        with torch.no_grad():
            nnet_output, encoder_memory, memory_mask = model(
                feature, supervision_segments)
            if att_rate != 0.0:
                att_loss = model.decoder_forward(encoder_memory, memory_mask,
                                                 supervisions, graph_compiler)

    # nnet_output is [N, C, T]
    nnet_output = nnet_output.permute(0, 2, 1)  # now it is [N, T, C]

    decoding_graph = graph_compiler.compile(texts).to(device)

    # nnet_output2 = nnet_output.clone()
    # blank_bias = -7.0
    # nnet_output2[:,:,0] += blank_bias

    dense_fsa_vec = k2.DenseFsaVec(nnet_output, supervision_segments)
    assert decoding_graph.is_cuda()
    assert decoding_graph.device == device
    assert nnet_output.device == device

    target_graph = k2.intersect_dense(decoding_graph, dense_fsa_vec, 10.0)

    tot_scores = target_graph.get_tot_scores(log_semiring=True,
                                             use_double_scores=True)
    (tot_score, tot_frames,
     all_frames) = get_tot_objf_and_num_frames(tot_scores,
                                               supervision_segments[:, 2])

    if is_training:
        if att_rate != 0.0:
            loss = (-(1.0 - att_rate) * tot_score +
                    att_rate * att_loss) / (len(texts) * accum_grad)
        else:
            loss = (-tot_score) / (len(texts) * accum_grad)
        loss.backward()
        if is_update:
            clip_grad_value_(model.parameters(), 5.0)
            optimizer.step()
            optimizer.zero_grad()

    ans = (-tot_score.detach().cpu().item(), tot_frames.cpu().item(),
           all_frames.cpu().item())
    return ans

def get_objf(
    batch: Dict,
    model: AcousticModel,
    device: torch.device,
    graph_compiler: CtcTrainingGraphCompiler,
    training: bool,
    optimizer: Optional[torch.optim.Optimizer] = None,
):
    feature = batch["inputs"]
    supervisions = batch["supervisions"]
    supervision_segments = torch.stack(
        (
            supervisions["sequence_idx"],
            torch.floor_divide(supervisions["start_frame"],
                               model.subsampling_factor),
            torch.floor_divide(supervisions["num_frames"],
                               model.subsampling_factor),
        ),
        1,
    ).to(torch.int32)
    indices = torch.argsort(supervision_segments[:, 2], descending=True)
    supervision_segments = supervision_segments[indices]
    texts = supervisions["text"]
    texts = [texts[idx] for idx in indices]

    assert feature.ndim == 3
    feature = feature.to(device)
    # at entry, feature is [N, T, C]
    feature = feature.permute(0, 2, 1)  # now feature is [N, C, T]

    if training:
        nnet_output = model(feature)
    else:
        with torch.no_grad():
            nnet_output = model(feature)

    # nnet_output is [N, C, T]
    nnet_output = nnet_output.permute(0, 2, 1)  # now it is [N, T, C]

    decoding_graph = graph_compiler.compile(texts).to(device)

    dense_fsa_vec = k2.DenseFsaVec(nnet_output, supervision_segments)
    assert decoding_graph.is_cuda()
    assert decoding_graph.device == device
    assert nnet_output.device == device

    target_graph = k2.intersect_dense(decoding_graph, dense_fsa_vec, 10.0)
    tot_scores = target_graph.get_tot_scores(log_semiring=True,
                                             use_double_scores=True)
    (tot_score, tot_frames,
     all_frames) = get_tot_objf_and_num_frames(tot_scores,
                                               supervision_segments[:, 2])

    if training:
        optimizer.zero_grad()
        (-tot_score).backward()
        clip_grad_value_(model.parameters(), 5.0)
        optimizer.step()

    ans = (
        -tot_score.detach().cpu().item(),
        tot_frames.cpu().item(),
        all_frames.cpu().item(),
    )
    return ans

def get_objf(batch: Dict,
             model: AcousticModel,
             device: torch.device,
             graph_compiler: CtcTrainingGraphCompiler,
             is_training: bool,
             is_update: bool,
             accum_grad: int = 1,
             att_rate: float = 0.0,
             tb_writer: Optional[SummaryWriter] = None,
             global_batch_idx_train: Optional[int] = None,
             optimizer: Optional[torch.optim.Optimizer] = None):
    feature = batch['inputs']
    feature = feature.to(device)
    # at entry, feature is [N, T, C]
    feature = feature.permute(0, 2, 1)  # now feature is [N, C, T]

    supervisions = batch['supervisions']
    supervision_segments, texts = encode_supervisions(supervisions)

    loss_fn = CTCLoss(graph_compiler)

    grad_context = nullcontext if is_training else torch.no_grad

    with grad_context():
        nnet_output, encoder_memory, memory_mask = model(feature,
                                                         supervisions)
        if att_rate != 0.0:
            att_loss = model.decoder_forward(encoder_memory, memory_mask,
                                             supervisions, graph_compiler)

        # nnet_output is [N, C, T]
        nnet_output = nnet_output.permute(0, 2, 1)  # now it is [N, T, C]
        tot_score, tot_frames, all_frames = loss_fn(nnet_output, texts,
                                                    supervision_segments)

    if is_training:

        def maybe_log_gradients(tag: str):
            if (tb_writer is not None and global_batch_idx_train is not None
                    and global_batch_idx_train % 200 == 0):
                tb_writer.add_scalars(tag,
                                      measure_gradient_norms(model,
                                                             norm='l1'),
                                      global_step=global_batch_idx_train)

        if att_rate != 0.0:
            loss = (-(1.0 - att_rate) * tot_score +
                    att_rate * att_loss) / (len(texts) * accum_grad)
        else:
            loss = (-tot_score) / (len(texts) * accum_grad)
        loss.backward()

        if is_update:
            maybe_log_gradients('train/grad_norms')
            clip_grad_value_(model.parameters(), 5.0)
            maybe_log_gradients('train/clipped_grad_norms')
            if tb_writer is not None and \
                    (global_batch_idx_train // accum_grad) % 200 == 0:
                # Once in a while we perform a more costly diagnostic
                # to check the relative parameter change per minibatch.
                deltas = optim_step_and_measure_param_change(model, optimizer)
                tb_writer.add_scalars(
                    'train/relative_param_change_per_minibatch',
                    deltas,
                    global_step=global_batch_idx_train)
            else:
                optimizer.step()
            optimizer.zero_grad()

    ans = (-tot_score.detach().cpu().item(), tot_frames.cpu().item(),
           all_frames.cpu().item())
    return ans

def train_one_epoch(dataloader: torch.utils.data.DataLoader,
                    valid_dataloader: torch.utils.data.DataLoader,
                    model: AcousticModel,
                    P: k2.Fsa,
                    device: torch.device,
                    graph_compiler: MmiTrainingGraphCompiler,
                    optimizer: torch.optim.Optimizer,
                    current_epoch: int,
                    tb_writer: Optional[SummaryWriter],
                    num_epochs: int,
                    global_batch_idx_train: int):
    total_objf, total_frames, total_all_frames = 0., 0., 0.
    valid_average_objf = float('inf')
    time_waiting_for_batch = 0
    prev_timestamp = datetime.now()

    model.train()
    for batch_idx, batch in enumerate(dataloader):
        global_batch_idx_train += 1
        timestamp = datetime.now()
        time_waiting_for_batch += (timestamp - prev_timestamp).total_seconds()

        if isinstance(model, DDP):
            P.set_scores_stochastic_(model.module.P_scores)
        else:
            P.set_scores_stochastic_(model.P_scores)
        assert P.is_cpu
        assert P.requires_grad is True

        curr_batch_objf, curr_batch_frames, curr_batch_all_frames = get_objf(
            batch=batch,
            model=model,
            P=P,
            device=device,
            graph_compiler=graph_compiler,
            is_training=True,
            tb_writer=tb_writer,
            global_batch_idx_train=global_batch_idx_train,
            optimizer=optimizer)

        total_objf += curr_batch_objf
        total_frames += curr_batch_frames
        total_all_frames += curr_batch_all_frames

        if batch_idx % 10 == 0 and dist.get_rank() == 0:
            logging.info(
                'batch {}, epoch {}/{} '
                'global average objf: {:.6f} over {} '
                'frames ({:.1f}% kept), current batch average objf: {:.6f} over {} frames ({:.1f}% kept) '
                'avg time waiting for batch {:.3f}s'.format(
                    batch_idx, current_epoch, num_epochs,
                    total_objf / total_frames, total_frames,
                    100.0 * total_frames / total_all_frames,
                    curr_batch_objf / (curr_batch_frames + 0.001),
                    curr_batch_frames,
                    100.0 * curr_batch_frames / curr_batch_all_frames,
                    time_waiting_for_batch / max(1, batch_idx)))
            tb_writer.add_scalar('train/global_average_objf',
                                 total_objf / total_frames,
                                 global_batch_idx_train)
            tb_writer.add_scalar('train/current_batch_average_objf',
                                 curr_batch_objf / (curr_batch_frames + 0.001),
                                 global_batch_idx_train)

        # if batch_idx >= 10:
        #     print("Exiting early to get profile info")
        #     sys.exit(0)

        if batch_idx > 0 and batch_idx % 1000 == 0:
            total_valid_objf, total_valid_frames, total_valid_all_frames = \
                get_validation_objf(dataloader=valid_dataloader,
                                    model=model,
                                    P=P,
                                    device=device,
                                    graph_compiler=graph_compiler)
            # Synchronize the loss to the master node so that we display it
            # correctly. dist.reduce performs sum reduction by default.
            valid_average_objf = total_valid_objf / total_valid_frames
            model.train()
            if dist.get_rank() == 0:
                logging.info(
                    'Validation average objf: {:.6f} over {} frames ({:.1f}% kept)'
                    .format(valid_average_objf, total_valid_frames,
                            100.0 * total_valid_frames /
                            total_valid_all_frames))
            if tb_writer is not None:
                tb_writer.add_scalar('train/global_valid_average_objf',
                                     valid_average_objf,
                                     global_batch_idx_train)
                (model.module if isinstance(model, DDP) else
                 model).write_tensorboard_diagnostics(
                     tb_writer, global_step=global_batch_idx_train)
        prev_timestamp = datetime.now()
    return total_objf / total_frames, valid_average_objf, \
        global_batch_idx_train

def get_objf(batch: Dict,
             model: AcousticModel,
             device: torch.device,
             graph_compiler: CtcTrainingGraphCompiler,
             training: bool,
             optimizer: Optional[torch.optim.Optimizer] = None):
    feature = batch['features']
    supervisions = batch['supervisions']
    supervision_segments = torch.stack(
        (supervisions['sequence_idx'],
         torch.floor_divide(supervisions['start_frame'],
                            model.subsampling_factor),
         torch.floor_divide(supervisions['num_frames'],
                            model.subsampling_factor)), 1).to(torch.int32)
    indices = torch.argsort(supervision_segments[:, 2], descending=True)
    supervision_segments = supervision_segments[indices]

    texts = supervisions['text']
    texts = [texts[idx] for idx in indices]

    assert feature.ndim == 3
    # print(supervision_segments[:, 1] + supervision_segments[:, 2])

    feature = feature.to(device)
    # at entry, feature is [N, T, C]
    feature = feature.permute(0, 2, 1)  # now feature is [N, C, T]
    if training:
        nnet_output = model(feature)
    else:
        with torch.no_grad():
            nnet_output = model(feature)

    # nnet_output is [N, C, T]
    nnet_output = nnet_output.permute(0, 2, 1)  # now it is [N, T, C]

    decoding_graph = graph_compiler.compile(texts).to(device)

    # nnet_output2 = nnet_output.clone()
    # blank_bias = -7.0
    # nnet_output2[:,:,0] += blank_bias

    dense_fsa_vec = k2.DenseFsaVec(nnet_output, supervision_segments)
    assert decoding_graph.is_cuda()
    assert decoding_graph.device == device
    assert nnet_output.device == device

    # TODO(haowen): with a small `beam`, we may get empty `target_graph`,
    # thus `tot_scores` will be `inf`. Definitely we need to handle this
    # later.
    target_graph = k2.intersect_dense(decoding_graph, dense_fsa_vec, 10.0)
    tot_scores = target_graph.get_tot_scores(log_semiring=True,
                                             use_double_scores=True)
    (tot_score, tot_frames,
     all_frames) = get_tot_objf_and_num_frames(tot_scores,
                                               supervision_segments[:, 2])

    if training:
        optimizer.zero_grad()
        (-tot_score).backward()
        clip_grad_value_(model.parameters(), 5.0)
        optimizer.step()

    ans = (-tot_score.detach().cpu().item(), tot_frames.cpu().item(),
           all_frames.cpu().item())
    return ans

def get_objf(batch: Dict,
             model: AcousticModel,
             P: k2.Fsa,
             device: torch.device,
             graph_compiler: MmiTrainingGraphCompiler,
             is_training: bool,
             den_scale: float = 1.0,
             tb_writer: Optional[SummaryWriter] = None,
             global_batch_idx_train: Optional[int] = None,
             optimizer: Optional[torch.optim.Optimizer] = None):
    feature = batch['features']
    supervisions = batch['supervisions']
    subsampling_factor = model.module.subsampling_factor if isinstance(
        model, DDP) else model.subsampling_factor
    supervision_segments = torch.stack(
        (supervisions['sequence_idx'],
         torch.floor_divide(supervisions['start_frame'], subsampling_factor),
         torch.floor_divide(supervisions['num_frames'], subsampling_factor)),
        1).to(torch.int32)
    indices = torch.argsort(supervision_segments[:, 2], descending=True)
    supervision_segments = supervision_segments[indices]

    texts = supervisions['text']
    texts = [texts[idx] for idx in indices]

    assert feature.ndim == 3
    # print(supervision_segments[:, 1] + supervision_segments[:, 2])

    feature = feature.to(device)
    # at entry, feature is [N, T, C]
    feature = feature.permute(0, 2, 1)  # now feature is [N, C, T]
    if is_training:
        nnet_output = model(feature)
    else:
        with torch.no_grad():
            nnet_output = model(feature)

    # nnet_output is [N, C, T]
    nnet_output = nnet_output.permute(0, 2, 1)  # now it is [N, T, C]

    if is_training:
        num, den = graph_compiler.compile(texts, P)
    else:
        with torch.no_grad():
            num, den = graph_compiler.compile(texts, P)

    assert num.requires_grad == is_training
    assert den.requires_grad is False
    num = num.to(device)
    den = den.to(device)

    # nnet_output2 = nnet_output.clone()
    # blank_bias = -7.0
    # nnet_output2[:,:,0] += blank_bias

    dense_fsa_vec = k2.DenseFsaVec(nnet_output, supervision_segments)
    assert nnet_output.device == device

    num = k2.intersect_dense(num, dense_fsa_vec, 10.0)
    den = k2.intersect_dense(den, dense_fsa_vec, 10.0)

    num_tot_scores = num.get_tot_scores(log_semiring=True,
                                        use_double_scores=True)
    den_tot_scores = den.get_tot_scores(log_semiring=True,
                                        use_double_scores=True)
    tot_scores = num_tot_scores - den_scale * den_tot_scores

    (tot_score, tot_frames,
     all_frames) = get_tot_objf_and_num_frames(tot_scores,
                                               supervision_segments[:, 2])

    if is_training:

        def maybe_log_gradients(tag: str):
            if (tb_writer is not None and global_batch_idx_train is not None
                    and global_batch_idx_train % 200 == 0):
                tb_writer.add_scalars(tag,
                                      measure_gradient_norms(model,
                                                             norm='l1'),
                                      global_step=global_batch_idx_train)

        optimizer.zero_grad()
        (-tot_score).backward()
        maybe_log_gradients('train/grad_norms')
        clip_grad_value_(model.parameters(), 5.0)
        maybe_log_gradients('train/clipped_grad_norms')

        if tb_writer is not None and global_batch_idx_train % 200 == 0:
            # Once in a while we perform a more costly diagnostic
            # to check the relative parameter change per minibatch.
            deltas = optim_step_and_measure_param_change(model, optimizer)
            tb_writer.add_scalars('train/relative_param_change_per_minibatch',
                                  deltas,
                                  global_step=global_batch_idx_train)
        else:
            optimizer.step()

    ans = (-tot_score.detach().cpu().item(), tot_frames.cpu().item(),
           all_frames.cpu().item())
    return ans