def get_objf(batch: Dict,
             model: AcousticModel,
             device: torch.device,
             graph_compiler: CtcTrainingGraphCompiler,
             training: bool,
             optimizer: Optional[torch.optim.Optimizer] = None):
    """Compute the objective for one batch and optionally update the model.

    The network output is intersected with the training graphs compiled
    from the supervision texts, and the total (log-semiring) score of the
    resulting lattices is used as the objective.

    Args:
      batch: Mapping with an 'inputs' tensor (3-D, [N, T, C] at entry) and
        a 'supervisions' mapping providing 'sequence_idx', 'start_frame',
        'num_frames' and 'text'.
      model: Called as ``model(feature)`` with feature laid out [N, C, T];
        its ``subsampling_factor`` attribute is used to convert the
        supervision frame indices to output-frame indices.
      device: Device the features and the decoding graph are moved to.
        The decoding graph is asserted to be on CUDA.
      graph_compiler: ``graph_compiler.compile(texts)`` builds the
        decoding graph for the batch.
      training: If True, run the forward pass with gradients, then
        zero the gradients, back-propagate, clip and step the optimizer.
        If False, the forward pass runs under ``torch.no_grad()``.
      optimizer: Optimizer to step; required when ``training`` is True.

    Returns:
      A tuple ``(objf, tot_frames, all_frames)`` of Python numbers, where
      ``objf`` is the negated total score and the two frame counts are
      those reported by ``get_tot_objf_and_num_frames``.

    Raises:
      ValueError: If ``training`` is True but no optimizer was given.
    """
    # Fail fast: without this check the missing optimizer only surfaces as
    # an AttributeError after the whole forward pass has been computed.
    if training and optimizer is None:
        raise ValueError(
            'get_objf: an optimizer is required when training is True')

    feature = batch['inputs']
    supervisions = batch['supervisions']

    # Each row is (sequence_idx, start_frame, num_frames), with the frame
    # quantities expressed in subsampled (network output) frames.
    subsampling_factor = model.subsampling_factor
    supervision_segments = torch.stack(
        (supervisions['sequence_idx'],
         torch.floor_divide(supervisions['start_frame'], subsampling_factor),
         torch.floor_divide(supervisions['num_frames'], subsampling_factor)),
        1).to(torch.int32)

    # Order the segments by decreasing duration and reorder the texts the
    # same way so that graphs and segments stay aligned.
    indices = torch.argsort(supervision_segments[:, 2], descending=True)
    supervision_segments = supervision_segments[indices]
    texts = supervisions['text']
    texts = [texts[idx] for idx in indices]

    assert feature.ndim == 3
    feature = feature.to(device)
    # at entry, feature is [N, T, C]
    feature = feature.permute(0, 2, 1)  # now feature is [N, C, T]

    if training:
        nnet_output = model(feature)
    else:
        with torch.no_grad():
            nnet_output = model(feature)

    # nnet_output is [N, C, T]
    nnet_output = nnet_output.permute(0, 2, 1)  # now nnet_output is [N, T, C]

    decoding_graph = graph_compiler.compile(texts).to(device)
    dense_fsa_vec = k2.DenseFsaVec(nnet_output, supervision_segments)

    assert decoding_graph.is_cuda()
    assert decoding_graph.device == device
    assert nnet_output.device == device

    # 10.0 is the third positional argument passed to intersect_dense.
    target_graph = k2.intersect_dense(decoding_graph, dense_fsa_vec, 10.0)
    tot_scores = target_graph.get_tot_scores(log_semiring=True,
                                             use_double_scores=True)

    (tot_score, tot_frames,
     all_frames) = get_tot_objf_and_num_frames(tot_scores,
                                               supervision_segments[:, 2])

    if training:
        optimizer.zero_grad()
        (-tot_score).backward()
        # Clip each gradient element to [-5, 5] before the update.
        clip_grad_value_(model.parameters(), 5.0)
        optimizer.step()

    ans = (-tot_score.detach().cpu().item(),
           tot_frames.cpu().item(),
           all_frames.cpu().item())
    return ans
def get_objf(batch: Dict,
             model: AcousticModel,
             device: torch.device,
             graph_compiler: CtcTrainingGraphCompiler,
             is_training: bool,
             is_update: bool,
             accum_grad: int = 1,
             att_rate: float = 0.0,
             optimizer: Optional[torch.optim.Optimizer] = None):
    """Compute the objective for one batch, with optional attention loss
    and gradient accumulation.

    Args:
      batch: Mapping with a 'features' tensor (3-D, [N, T, C] at entry)
        and a 'supervisions' mapping providing 'sequence_idx',
        'start_frame', 'num_frames' and 'text'.
      model: Called as ``model(feature, supervisions)`` with feature laid
        out [N, C, T]; returns ``(nnet_output, encoder_memory,
        memory_mask)``. ``model.decoder_forward`` is called when
        ``att_rate`` is non-zero.
      device: Device the features and the decoding graph are moved to.
        The decoding graph is asserted to be on CUDA.
      graph_compiler: ``graph_compiler.compile(texts)`` builds the
        decoding graph; it is also passed to ``model.decoder_forward``.
      is_training: If True, run the forward pass with gradients and
        back-propagate the loss. If False, the forward pass runs under
        ``torch.no_grad()`` and nothing is back-propagated.
      is_update: When training, whether to clip gradients, step the
        optimizer and zero the gradients after this batch.
      accum_grad: The loss is divided by ``len(texts) * accum_grad``
        before back-propagation.
      att_rate: Weight of the attention loss; the graph-based loss is
        weighted by ``1 - att_rate``. With 0.0 the decoder is not run.
      optimizer: Optimizer to step; required when both ``is_training``
        and ``is_update`` are True.

    Returns:
      A tuple ``(objf, tot_frames, all_frames)`` of Python numbers, where
      ``objf`` is the negated (un-normalized) total score and the two
      frame counts are those reported by ``get_tot_objf_and_num_frames``.

    Raises:
      ValueError: If an optimizer step is requested but no optimizer was
        given.
    """
    # Fail fast: without this check the missing optimizer only surfaces as
    # an AttributeError after forward and backward have already been run.
    if is_training and is_update and optimizer is None:
        raise ValueError('get_objf: an optimizer is required when '
                         'is_training and is_update are both True')

    feature = batch['features']
    supervisions = batch['supervisions']

    # Each row is (sequence_idx, start_frame, num_frames). The frame
    # quantities are mapped through ((x - 1) // 2 - 1) // 2 -- presumably
    # the model's subsampling arithmetic; confirm against the model.
    supervision_segments = torch.stack(
        (supervisions['sequence_idx'],
         (((supervisions['start_frame'] - 1) // 2 - 1) // 2),
         (((supervisions['num_frames'] - 1) // 2 - 1) // 2)),
        1).to(torch.int32)
    # The formula above can go negative for very small inputs.
    supervision_segments = torch.clamp(supervision_segments, min=0)

    # Order the segments by decreasing duration and reorder the texts the
    # same way so that graphs and segments stay aligned.
    indices = torch.argsort(supervision_segments[:, 2], descending=True)
    supervision_segments = supervision_segments[indices]
    texts = supervisions['text']
    texts = [texts[idx] for idx in indices]

    assert feature.ndim == 3
    feature = feature.to(device)
    # at entry, feature is [N, T, C]
    feature = feature.permute(0, 2, 1)  # now feature is [N, C, T]

    def _forward():
        # Run the encoder and, if requested, the attention decoder.
        output, encoder_memory, memory_mask = model(feature, supervisions)
        attention_loss = None
        if att_rate != 0.0:
            attention_loss = model.decoder_forward(encoder_memory,
                                                   memory_mask,
                                                   supervisions,
                                                   graph_compiler)
        return output, attention_loss

    if is_training:
        nnet_output, att_loss = _forward()
    else:
        # NOTE(review): att_loss is still computed here when att_rate is
        # non-zero, but it is not used outside training.
        with torch.no_grad():
            nnet_output, att_loss = _forward()

    # nnet_output is [N, C, T]
    nnet_output = nnet_output.permute(0, 2, 1)  # now nnet_output is [N, T, C]

    decoding_graph = graph_compiler.compile(texts).to(device)
    dense_fsa_vec = k2.DenseFsaVec(nnet_output, supervision_segments)

    assert decoding_graph.is_cuda()
    assert decoding_graph.device == device
    assert nnet_output.device == device

    # 10.0 is the third positional argument passed to intersect_dense.
    target_graph = k2.intersect_dense(decoding_graph, dense_fsa_vec, 10.0)
    tot_scores = target_graph.get_tot_scores(log_semiring=True,
                                             use_double_scores=True)

    (tot_score, tot_frames,
     all_frames) = get_tot_objf_and_num_frames(tot_scores,
                                               supervision_segments[:, 2])

    if is_training:
        # Normalize by the number of utterances and accumulation steps.
        normalizer = len(texts) * accum_grad
        if att_rate != 0.0:
            loss = (-(1.0 - att_rate) * tot_score +
                    att_rate * att_loss) / normalizer
        else:
            loss = (-tot_score) / normalizer
        loss.backward()
        if is_update:
            # Clip each gradient element to [-5, 5] before the update,
            # then clear the gradients for the next accumulation cycle.
            clip_grad_value_(model.parameters(), 5.0)
            optimizer.step()
            optimizer.zero_grad()

    ans = (-tot_score.detach().cpu().item(),
           tot_frames.cpu().item(),
           all_frames.cpu().item())
    return ans