def get_objf(batch: Dict,
             model: AcousticModel,
             device: torch.device,
             graph_compiler: CtcTrainingGraphCompiler,
             is_training: bool,
             is_update: bool,
             accum_grad: int = 1,
             att_rate: float = 0.0,
             tb_writer: Optional[SummaryWriter] = None,
             global_batch_idx_train: Optional[int] = None,
             optimizer: Optional[torch.optim.Optimizer] = None):
    """Compute the CTC (optionally CTC + attention) objective for one batch.

    When ``is_training`` is true the loss is back-propagated, and when
    ``is_update`` is also true the gradients are clipped and the optimizer
    is stepped and zeroed.

    Args:
        batch: Must provide ``batch['inputs']`` (features, [N, T, C] at
            entry) and ``batch['supervisions']``.
        model: Called as ``model(feature, supervisions)``; must return
            ``(nnet_output, encoder_memory, memory_mask)``.
        device: Device the features are moved to.
        graph_compiler: Passed to ``CTCLoss`` and to
            ``model.decoder_forward``.
        is_training: Enables gradient tracking and the backward pass.
        is_update: With ``is_training``, performs the optimizer step
            (used for gradient accumulation).
        accum_grad: Number of batches gradients are accumulated over; the
            loss is divided by ``len(texts) * accum_grad``.
        att_rate: Interpolation weight of the attention-decoder loss.
            ``0.0`` disables the attention decoder entirely.
        tb_writer: Optional TensorBoard writer for gradient diagnostics.
        global_batch_idx_train: Optional global step used for logging.
        optimizer: Required when ``is_training and is_update``.

    Returns:
        A tuple ``(-tot_score, tot_frames, all_frames)`` of Python numbers,
        taken from the three values returned by ``CTCLoss``.
    """
    feature = batch['inputs']
    feature = feature.to(device)
    # at entry, feature is [N, T, C]
    feature = feature.permute(0, 2, 1)  # now feature is [N, C, T]

    supervisions = batch['supervisions']
    supervision_segments, texts = encode_supervisions(supervisions)

    loss_fn = CTCLoss(graph_compiler)
    # Track gradients only when training.
    grad_context = nullcontext if is_training else torch.no_grad

    with grad_context():
        nnet_output, encoder_memory, memory_mask = model(feature, supervisions)
        if att_rate != 0.0:
            att_loss = model.decoder_forward(encoder_memory, memory_mask,
                                             supervisions, graph_compiler)

        # nnet_output is [N, C, T]
        nnet_output = nnet_output.permute(0, 2, 1)  # now nnet_output is [N, T, C]
        tot_score, tot_frames, all_frames = loss_fn(nnet_output, texts,
                                                    supervision_segments)

    if is_training:
        def maybe_log_gradients(tag: str):
            if (tb_writer is not None
                    and global_batch_idx_train is not None
                    and global_batch_idx_train % 200 == 0):
                tb_writer.add_scalars(tag,
                                      measure_gradient_norms(model, norm='l1'),
                                      global_step=global_batch_idx_train)

        if att_rate != 0.0:
            loss = (-(1.0 - att_rate) * tot_score
                    + att_rate * att_loss) / (len(texts) * accum_grad)
        else:
            loss = (-tot_score) / (len(texts) * accum_grad)
        loss.backward()

        if is_update:
            maybe_log_gradients('train/grad_norms')
            clip_grad_value_(model.parameters(), 5.0)
            maybe_log_gradients('train/clipped_grad_norms')

            # The step index is optional; without it (or without a writer)
            # there is nothing to log, so fall back to a plain step instead
            # of failing on ``None // accum_grad``.
            measure_param_change = (
                tb_writer is not None
                and global_batch_idx_train is not None
                and (global_batch_idx_train // accum_grad) % 200 == 0
            )
            if measure_param_change:
                # Once in a time we will perform a more costly diagnostic
                # to check the relative parameter change per minibatch.
                deltas = optim_step_and_measure_param_change(model, optimizer)
                tb_writer.add_scalars(
                    'train/relative_param_change_per_minibatch',
                    deltas,
                    global_step=global_batch_idx_train)
            else:
                optimizer.step()
            optimizer.zero_grad()

    ans = (-tot_score.detach().cpu().item(),
           tot_frames.cpu().item(),
           all_frames.cpu().item())
    return ans
def get_objf(batch: Dict,
             model: AcousticModel,
             ali_model: Optional[AcousticModel],
             device: torch.device,
             graph_compiler: MmiTrainingGraphCompiler,
             use_pruned_intersect: bool,
             is_training: bool,
             is_update: bool,
             accum_grad: int = 1,
             den_scale: float = 1.0,
             att_rate: float = 0.0,
             tb_writer: Optional[SummaryWriter] = None,
             global_batch_idx_train: Optional[int] = None,
             optimizer: Optional[torch.optim.Optimizer] = None,
             scaler: GradScaler = None):
    """Compute the LF-MMI (optionally + attention) objective for one batch.

    The forward pass runs under ``autocast`` when ``scaler`` is enabled.
    When ``is_training`` is true the (scaled) loss is back-propagated, and
    when ``is_update`` is also true gradients are unscaled and clipped, the
    optimizer is stepped and zeroed, and the scaler is updated.

    Args:
        batch: Must provide ``batch['inputs']`` (features, [N, T, C] at
            entry) and ``batch['supervisions']``. The batch is not modified.
        model: Called as ``model(feature, supervisions)``; must return
            ``(nnet_output, encoder_memory, memory_mask)``. If it has a
            ``module`` attribute, ``decoder_forward`` is looked up on that
            attribute instead.
        ali_model: Optional auxiliary model. During the first 4000 updates
            its output is added, with a decaying scale, to ``nnet_output``.
        device: Device the features are moved to.
        graph_compiler: Passed to ``LFMMILoss`` and ``decoder_forward``.
        use_pruned_intersect: Forwarded to ``LFMMILoss``.
        is_training: Enables gradient tracking and the backward pass.
        is_update: With ``is_training``, performs the optimizer step.
        accum_grad: Number of batches gradients are accumulated over.
        den_scale: Forwarded to ``LFMMILoss``.
        att_rate: Interpolation weight of the attention-decoder loss.
        tb_writer: Optional TensorBoard writer for gradient diagnostics.
        global_batch_idx_train: Optional global step used for logging and
            for the ``ali_model`` schedule.
        optimizer: Required when ``is_training and is_update``.
        scaler: AMP gradient scaler. ``None`` (the default) is treated as a
            disabled scaler.

    Returns:
        A tuple ``(-mmi_loss, tot_frames, all_frames)`` of Python numbers,
        taken from the three values returned by ``LFMMILoss``.
    """
    if scaler is None:
        # The default must be usable: a disabled scaler makes scale(),
        # unscale_(), step() and update() behave as without AMP.
        scaler = GradScaler(enabled=False)

    feature = batch['inputs']
    # at entry, feature is [N, T, C]
    feature = feature.permute(0, 2, 1)  # now feature is [N, C, T]
    assert feature.ndim == 3
    feature = feature.to(device)

    supervisions = batch['supervisions']
    supervision_segments, texts = encode_supervisions(supervisions)

    loss_fn = LFMMILoss(graph_compiler=graph_compiler,
                        den_scale=den_scale,
                        use_pruned_intersect=use_pruned_intersect)

    # Track gradients only when training.
    grad_context = nullcontext if is_training else torch.no_grad

    with autocast(enabled=scaler.is_enabled()), grad_context():
        if att_rate == 0:
            # Note: Make TorchScript happy by making the supervision dict strictly
            # conform to type Dict[str, Tensor]
            # Using the attention decoder with TorchScript is currently unsupported,
            # we'll need to separate out the 'text' field from 'supervisions' first.
            # A filtered copy is used so the caller's batch keeps its 'text'
            # entry and can be passed to this function again.
            supervisions = {
                key: value
                for key, value in supervisions.items()
                if key != 'text'
            }
        nnet_output, encoder_memory, memory_mask = model(feature, supervisions)
        if att_rate != 0.0:
            if hasattr(model, 'module'):
                att_loss = model.module.decoder_forward(
                    encoder_memory, memory_mask, supervisions, graph_compiler)
            else:
                att_loss = model.decoder_forward(
                    encoder_memory, memory_mask, supervisions, graph_compiler)

        if (ali_model is not None and global_batch_idx_train is not None
                and global_batch_idx_train // accum_grad < 4000):
            with torch.no_grad():
                ali_model_output = ali_model(feature)
            if ali_model_output.isinf().any() or ali_model_output.isnan().any():
                logging.warning(
                    "Found 'nan' or 'inf' in ali_model_output... Setting it to zero."
                )
                ali_model_output[ali_model_output.isinf()] = 0.0
                ali_model_output[ali_model_output.isnan()] = 0.0
            # subsampling is done slightly differently, may be small length
            # differences.
            min_len = min(ali_model_output.shape[2], nnet_output.shape[2])
            # scale less than one so it will be encouraged
            # to mimic ali_model's output
            ali_model_scale = 500.0 / (global_batch_idx_train // accum_grad + 500)
            nnet_output = nnet_output.clone()  # or log-softmax backprop will fail.
            nnet_output[:, :, :min_len] += (
                ali_model_scale * ali_model_output[:, :, :min_len])

        # nnet_output is [N, C, T]
        nnet_output = nnet_output.permute(0, 2, 1)  # now nnet_output is [N, T, C]

        mmi_loss, tot_frames, all_frames = loss_fn(nnet_output, texts,
                                                   supervision_segments)

    if is_training:
        def maybe_log_gradients(tag: str):
            if (tb_writer is not None
                    and global_batch_idx_train is not None
                    and global_batch_idx_train % 200 == 0):
                tb_writer.add_scalars(tag,
                                      measure_gradient_norms(model, norm='l1'),
                                      global_step=global_batch_idx_train)

        if att_rate != 0.0:
            loss = (-(1.0 - att_rate) * mmi_loss
                    + att_rate * att_loss) / (len(texts) * accum_grad)
        else:
            loss = (-mmi_loss) / (len(texts) * accum_grad)
        scaler.scale(loss).backward()

        if is_update:
            maybe_log_gradients('train/grad_norms')
            # Gradients must be unscaled before clipping by value.
            scaler.unscale_(optimizer)
            clip_grad_value_(model.parameters(), 5.0)
            maybe_log_gradients('train/clipped_grad_norms')

            # The step index is optional; without it (or without a writer)
            # there is nothing to log, so fall back to a plain step instead
            # of failing on ``None // accum_grad``.
            measure_param_change = (
                tb_writer is not None
                and global_batch_idx_train is not None
                and (global_batch_idx_train // accum_grad) % 200 == 0
            )
            if measure_param_change:
                # Once in a time we will perform a more costly diagnostic
                # to check the relative parameter change per minibatch.
                deltas = optim_step_and_measure_param_change(
                    model, optimizer, scaler)
                tb_writer.add_scalars(
                    'train/relative_param_change_per_minibatch',
                    deltas,
                    global_step=global_batch_idx_train)
            else:
                scaler.step(optimizer)
            optimizer.zero_grad()
            scaler.update()

    ans = (-mmi_loss.detach().cpu().item(),
           tot_frames.cpu().item(),
           all_frames.cpu().item())
    return ans
def get_objf(batch: Dict,
             model: AcousticModel,
             device: torch.device,
             graph_compiler: CtcTrainingGraphCompiler,
             is_training: bool,
             is_update: bool,
             accum_grad: int = 1,
             att_rate: float = 0.0,
             optimizer: Optional[torch.optim.Optimizer] = None):
    """Compute the CTC (optionally CTC + attention) objective for one batch.

    The supervision segments are built from ``batch['supervisions']``,
    sorted by decreasing duration, and intersected with the compiled
    transcript graphs. When ``is_training`` is true the loss is
    back-propagated; when ``is_update`` is also true the gradients are
    clipped and the optimizer is stepped and zeroed.

    Returns:
        A tuple ``(-tot_score, tot_frames, all_frames)`` of Python numbers,
        as produced by ``get_tot_objf_and_num_frames``.
    """
    def _to_output_frames(frames):
        # Applied to both start offsets and durations; presumably mirrors
        # the model's two stride-2 subsampling stages -- TODO confirm.
        return ((frames - 1) // 2 - 1) // 2

    def _run_model():
        # Forward pass plus the optional attention-decoder loss.
        output, encoder_memory, memory_mask = model(feature, segments)
        decoder_loss = None
        if att_rate != 0.0:
            decoder_loss = model.decoder_forward(encoder_memory, memory_mask,
                                                 supervisions, graph_compiler)
        return output, decoder_loss

    feature = batch['features']
    supervisions = batch['supervisions']

    # One row per supervision: (sequence index, start frame, num frames).
    columns = (
        supervisions['sequence_idx'],
        _to_output_frames(supervisions['start_frame']),
        _to_output_frames(supervisions['num_frames']),
    )
    segments = torch.stack(columns, 1).to(torch.int32)
    segments = torch.clamp(segments, min=0)

    # Longest segment first; the transcripts follow the same order.
    order = torch.argsort(segments[:, 2], descending=True)
    segments = segments[order]
    all_texts = supervisions['text']
    texts = [all_texts[idx] for idx in order]

    assert feature.ndim == 3
    feature = feature.to(device)
    # [N, T, C] -> [N, C, T]
    feature = feature.permute(0, 2, 1)

    if is_training:
        nnet_output, att_loss = _run_model()
    else:
        with torch.no_grad():
            nnet_output, att_loss = _run_model()

    # [N, C, T] -> [N, T, C]
    nnet_output = nnet_output.permute(0, 2, 1)

    decoding_graph = graph_compiler.compile(texts).to(device)
    dense_fsa_vec = k2.DenseFsaVec(nnet_output, segments)

    assert decoding_graph.is_cuda()
    assert decoding_graph.device == device
    assert nnet_output.device == device

    target_graph = k2.intersect_dense(decoding_graph, dense_fsa_vec, 10.0)
    tot_scores = target_graph.get_tot_scores(log_semiring=True,
                                             use_double_scores=True)
    tot_score, tot_frames, all_frames = get_tot_objf_and_num_frames(
        tot_scores, segments[:, 2])

    if is_training:
        if att_rate != 0.0:
            numerator = -(1.0 - att_rate) * tot_score + att_rate * att_loss
        else:
            numerator = -tot_score
        loss = numerator / (len(texts) * accum_grad)
        loss.backward()
        if is_update:
            clip_grad_value_(model.parameters(), 5.0)
            optimizer.step()
            optimizer.zero_grad()

    return (-tot_score.detach().cpu().item(),
            tot_frames.cpu().item(),
            all_frames.cpu().item())
def get_objf(batch: Dict,
             model: AcousticModel,
             P: k2.Fsa,
             device: torch.device,
             graph_compiler: MmiTrainingGraphCompiler,
             is_training: bool,
             is_update: bool,
             accum_grad: int = 1,
             den_scale: float = 1.0,
             att_rate: float = 0.0,
             tb_writer: Optional[SummaryWriter] = None,
             global_batch_idx_train: Optional[int] = None,
             optimizer: Optional[torch.optim.Optimizer] = None):
    """Compute the MMI (optionally MMI + attention) objective for one batch.

    Numerator and denominator graphs are compiled from the transcripts and
    ``P``, intersected with the network output, and combined as
    ``num - den_scale * den``. When ``is_training`` is true the loss is
    back-propagated; when ``is_update`` is also true the gradients are
    clipped and the optimizer is stepped and zeroed.

    Args:
        batch: Must provide ``batch['features']`` (3-D, [N, T, C] at entry)
            and ``batch['supervisions']`` with ``sequence_idx``,
            ``start_frame``, ``num_frames`` and ``text``.
        model: Called as ``model(feature, supervisions)``; must return
            ``(nnet_output, encoder_memory, memory_mask)``.
        P: Passed to ``graph_compiler.compile`` together with the
            transcripts (presumably the phone LM -- confirm with caller).
        device: Device the features and graphs are moved to.
        graph_compiler: Produces the ``(num, den)`` graphs.
        is_training: Enables gradient tracking and the backward pass.
        is_update: With ``is_training``, performs the optimizer step.
        accum_grad: Number of batches gradients are accumulated over.
        den_scale: Weight of the denominator score.
        att_rate: Interpolation weight of the attention-decoder loss.
        tb_writer: Optional TensorBoard writer for gradient diagnostics.
        global_batch_idx_train: Optional global step used for logging.
        optimizer: Required when ``is_training and is_update``.

    Returns:
        A tuple ``(-tot_score, tot_frames, all_frames)`` of Python numbers,
        as produced by ``get_tot_objf_and_num_frames``.
    """
    feature = batch['features']
    supervisions = batch['supervisions']
    # One row per supervision: (sequence index, start frame, num frames).
    # The frame arithmetic presumably mirrors the model's subsampling --
    # TODO confirm against the model definition.
    supervision_segments = torch.stack(
        (supervisions['sequence_idx'],
         (((supervisions['start_frame'] - 1) // 2 - 1) // 2),
         (((supervisions['num_frames'] - 1) // 2 - 1) // 2)),
        1).to(torch.int32)
    supervision_segments = torch.clamp(supervision_segments, min=0)
    # Sort by decreasing duration and reorder the transcripts to match.
    indices = torch.argsort(supervision_segments[:, 2], descending=True)
    supervision_segments = supervision_segments[indices]

    texts = supervisions['text']
    texts = [texts[idx] for idx in indices]
    assert feature.ndim == 3

    feature = feature.to(device)
    # at entry, feature is [N, T, C]
    feature = feature.permute(0, 2, 1)  # now feature is [N, C, T]

    def _forward():
        # Forward pass plus the optional attention-decoder loss.
        output, encoder_memory, memory_mask = model(feature, supervisions)
        decoder_loss = None
        if att_rate != 0.0:
            decoder_loss = model.decoder_forward(encoder_memory, memory_mask,
                                                 supervisions, graph_compiler)
        return output, decoder_loss

    if is_training:
        nnet_output, att_loss = _forward()
    else:
        with torch.no_grad():
            nnet_output, att_loss = _forward()

    # nnet_output is [N, C, T]
    nnet_output = nnet_output.permute(0, 2, 1)  # now nnet_output is [N, T, C]

    if is_training:
        num, den = graph_compiler.compile(texts, P)
    else:
        with torch.no_grad():
            num, den = graph_compiler.compile(texts, P)

    assert num.requires_grad == is_training
    assert den.requires_grad is False
    num = num.to(device)
    den = den.to(device)

    dense_fsa_vec = k2.DenseFsaVec(nnet_output, supervision_segments)
    assert nnet_output.device == device

    num = k2.intersect_dense(num, dense_fsa_vec, 10.0)
    den = k2.intersect_dense(den, dense_fsa_vec, 10.0)

    num_tot_scores = num.get_tot_scores(log_semiring=True,
                                        use_double_scores=True)
    den_tot_scores = den.get_tot_scores(log_semiring=True,
                                        use_double_scores=True)
    tot_scores = num_tot_scores - den_scale * den_tot_scores

    (tot_score, tot_frames,
     all_frames) = get_tot_objf_and_num_frames(tot_scores,
                                               supervision_segments[:, 2])

    if is_training:
        def maybe_log_gradients(tag: str):
            if (tb_writer is not None
                    and global_batch_idx_train is not None
                    and global_batch_idx_train % 200 == 0):
                tb_writer.add_scalars(tag,
                                      measure_gradient_norms(model, norm='l1'),
                                      global_step=global_batch_idx_train)

        if att_rate != 0.0:
            loss = (-(1.0 - att_rate) * tot_score
                    + att_rate * att_loss) / (len(texts) * accum_grad)
        else:
            loss = (-tot_score) / (len(texts) * accum_grad)
        loss.backward()

        if is_update:
            maybe_log_gradients('train/grad_norms')
            clip_grad_value_(model.parameters(), 5.0)
            maybe_log_gradients('train/clipped_grad_norms')

            # Both the writer and the step index are optional (default
            # None). Only run the logging diagnostic when both are present;
            # otherwise just take a plain optimizer step.
            measure_param_change = (
                tb_writer is not None
                and global_batch_idx_train is not None
                and (global_batch_idx_train // accum_grad) % 200 == 0
            )
            if measure_param_change:
                # Once in a time we will perform a more costly diagnostic
                # to check the relative parameter change per minibatch.
                deltas = optim_step_and_measure_param_change(model, optimizer)
                tb_writer.add_scalars(
                    'train/relative_param_change_per_minibatch',
                    deltas,
                    global_step=global_batch_idx_train)
            else:
                optimizer.step()
            optimizer.zero_grad()

    ans = (-tot_score.detach().cpu().item(),
           tot_frames.cpu().item(),
           all_frames.cpu().item())
    return ans