def _intersect_calc_scores_mmi_pruned(
    self,
    dense_fsa_vec: k2.DenseFsaVec,
    num_graphs: 'k2.Fsa',
    den_graph: 'k2.Fsa',
    return_lats: bool = True,
):
    """Intersect numerator and denominator graphs with the network output.

    The numerator graphs are intersected with ``k2.intersect_dense``; the
    denominator graph with ``k2.intersect_dense_pruned``, using the beams and
    active-state limits stored in ``self.intersect_conf``.

    Args:
        dense_fsa_vec: network output wrapped as a ``k2.DenseFsaVec``.
        num_graphs: numerator graphs; their count must equal
            ``dense_fsa_vec.dim0()``.
        den_graph: denominator graph(s).
        return_lats: when True the lattices are returned as well and carry a
            ``seqframe_idx`` attribute.

    Returns:
        ``(num_tot_scores, den_tot_scores, num_lats, den_lats)``; the last two
        items are ``None`` when ``return_lats`` is False.
    """
    device = dense_fsa_vec.device
    assert device == num_graphs.device and device == den_graph.device
    assert dense_fsa_vec.dim0() == num_graphs.shape[0]

    conf = self.intersect_conf
    seqframe_name = "seqframe_idx" if return_lats else None

    num_lats = k2.intersect_dense(
        a_fsas=num_graphs,
        b_fsas=dense_fsa_vec,
        output_beam=conf.output_beam,
        seqframe_idx_name=seqframe_name,
    )
    den_lats = k2.intersect_dense_pruned(
        a_fsas=den_graph,
        b_fsas=dense_fsa_vec,
        search_beam=conf.search_beam,
        output_beam=conf.output_beam,
        min_active_states=conf.min_active_states,
        max_active_states=conf.max_active_states,
        seqframe_idx_name=seqframe_name,
    )

    # Double precision matters here: single-precision totals sometimes
    # suffer from rounding errors.
    num_tot_scores = num_lats.get_tot_scores(log_semiring=True, use_double_scores=True)
    den_tot_scores = den_lats.get_tot_scores(log_semiring=True, use_double_scores=True)

    lats = (num_lats, den_lats) if return_lats else (None, None)
    return (num_tot_scores, den_tot_scores) + lats
def decode(dataloader: torch.utils.data.DataLoader, model: AcousticModel,
           device: Union[str, torch.device], HLG: Fsa, symbols: SymbolTable):
    """Decode every batch of ``dataloader`` with the HLG graph (1-best).

    Args:
        dataloader: yields dicts with ``'inputs'`` ([N, T, C] features) and
            ``'supervisions'`` (``sequence_idx``, ``start_frame``,
            ``num_frames``, ``text``); ``dataloader.dataset.cuts`` is used
            for progress reporting.
        model: acoustic model; ``model.subsampling_factor`` maps supervision
            frame indexes to output frames.
        device: device the features are moved to.
        HLG: decoding graph; must live on the same device as the model output.
        symbols: symbol table mapping hypothesis ids to words.

    Returns:
        A list of ``(ref_words, hyp_words)`` pairs.
    """
    tot_num_cuts = len(dataloader.dataset.cuts)
    num_cuts = 0
    results = []  # a list of pair (ref_words, hyp_words)
    for batch_idx, batch in enumerate(dataloader):
        feature = batch['inputs']
        supervisions = batch['supervisions']
        supervision_segments = torch.stack(
            (supervisions['sequence_idx'],
             torch.floor_divide(supervisions['start_frame'],
                                model.subsampling_factor),
             torch.floor_divide(supervisions['num_frames'],
                                model.subsampling_factor)), 1).to(torch.int32)
        # Sort segments by decreasing length. `texts` stays in the original
        # order; `indices` is handed to get_texts() below (presumably to undo
        # this sort so hyps[i] lines up with texts[i]).
        indices = torch.argsort(supervision_segments[:, 2], descending=True)
        supervision_segments = supervision_segments[indices]
        texts = supervisions['text']
        assert feature.ndim == 3

        feature = feature.to(device)
        # at entry, feature is [N, T, C]
        feature = feature.permute(0, 2, 1)  # now feature is [N, C, T]
        with torch.no_grad():
            nnet_output = model(feature)
        # nnet_output is [N, C, T]
        nnet_output = nnet_output.permute(0, 2, 1)  # now nnet_output is [N, T, C]

        dense_fsa_vec = k2.DenseFsaVec(nnet_output, supervision_segments)

        assert HLG.device == nnet_output.device, \
            f"Check failed: HLG.device ({HLG.device}) == nnet_output.device ({nnet_output.device})"
        # TODO(haowen): with a small `beam`, we may get empty `target_graph`,
        # thus `tot_scores` will be `inf`. Definitely we need to handle this later.
        lattices = k2.intersect_dense_pruned(HLG, dense_fsa_vec, 20.0, 7.0, 30, 10000)
        best_paths = k2.shortest_path(lattices, use_double_scores=True)
        assert best_paths.shape[0] == len(texts)
        hyps = get_texts(best_paths, indices)
        assert len(hyps) == len(texts)

        for i, text in enumerate(texts):
            hyp_words = [symbols.get(x) for x in hyps[i]]
            ref_words = text.split(' ')
            results.append((ref_words, hyp_words))

        if batch_idx % 10 == 0:
            logging.info(
                'batch {}, cuts processed until now is {}/{} ({:.6f}%)'.format(
                    batch_idx, num_cuts, tot_num_cuts,
                    float(num_cuts) / tot_num_cuts * 100))

        num_cuts += len(texts)

    return results
def test_two_dense(self):
    """Intersect a single FSA with two dense segments (the FSA is broadcast).

    Checks the ``seqframe``/``frame`` arc attributes, the arc scores of the
    result and the gradients w.r.t. the FSA scores and the log-probs.
    Expected values were computed using gtn; see https://bit.ly/3oYObeb.
    """
    fsa_text = '''
        0 1 1 1.0
        1 1 1 50.0
        1 2 2 2.0
        2 3 -1 3.0
        3
    '''
    for dev in self.devices:
        fsa = k2.Fsa.from_str(fsa_text).to(dev)
        fsa.requires_grad_(True)
        fsa_vec = k2.create_fsa_vec([fsa])

        log_prob = torch.tensor(
            [[[0.1, 0.2, 0.3], [0.04, 0.05, 0.06], [0.0, 0.0, 0.0]],
             [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.0, 0.0, 0.0]]],
            dtype=torch.float32,
            device=dev,
            requires_grad=True)
        segments = torch.tensor([[0, 0, 2], [1, 0, 3]], dtype=torch.int32)
        dense_fsa_vec = k2.DenseFsaVec(log_prob, segments)

        out_fsa = k2.intersect_dense_pruned(fsa_vec,
                                            dense_fsa_vec,
                                            search_beam=100000,
                                            output_beam=100000,
                                            min_active_states=0,
                                            max_active_states=10000,
                                            seqframe_idx_name='seqframe',
                                            frame_idx_name='frame')

        want_seqframe = torch.tensor([0, 1, 2, 3, 4, 5, 6], device=dev)
        want_frame = torch.tensor([0, 1, 2, 0, 1, 2, 3], device=dev)
        assert torch.all(torch.eq(out_fsa.seqframe, want_seqframe))
        assert torch.all(torch.eq(out_fsa.frame, want_frame))
        assert out_fsa.shape == (2, None, None), 'There should be two FSAs!'

        tot = out_fsa.get_tot_scores(log_semiring=False,
                                     use_double_scores=False)
        tot.sum().backward()

        want_arc_scores = torch.tensor([1.2, 2.06, 3.0, 1.2, 50.5, 2.0, 3.0],
                                       device=dev)
        want_fsa_grad = torch.tensor([2.0, 1.0, 2.0, 2.0], device=dev)
        want_log_prob_grad = torch.tensor([
            0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0, 0, 0, 0.0, 1.0, 0.0, 0.0, 1.0,
            0.0, 0.0, 0.0, 1.0
        ]).reshape_as(log_prob).to(dev)

        assert torch.allclose(out_fsa.scores, want_arc_scores)
        assert torch.allclose(want_fsa_grad, fsa.scores.grad)
        assert torch.allclose(want_log_prob_grad, log_prob.grad)
def test_two_fsas_long_pruned(self):
    """Pruned-intersection counterpart of ``test_two_fsas_long``.

    Mirrors the test of the same idea in intersect_dense_test.py but uses
    ``k2.intersect_dense_pruned``; runs on CPU and, when available, CUDA.
    """
    looped_text = '''
        0 1 1 1.0
        1 1 1 50.0
        1 2 2 2.0
        2 3 -1 3.0
        3
    '''

    linear_text = '''
        0 1 1 1.0
        1 2 2 2.0
        2 3 -1 3.0
        3
    '''

    devices = [torch.device('cpu')]
    if torch.cuda.is_available():
        devices.append(torch.device('cuda', 0))

    for dev in devices:
        fsa1 = k2.Fsa.from_str(looped_text)
        fsa2 = k2.Fsa.from_str(linear_text)
        for fsa in (fsa1, fsa2):
            fsa.requires_grad_(True)
        fsa_vec = k2.create_fsa_vec([fsa1, fsa2])

        log_prob = torch.rand((2, 100, 3),
                              dtype=torch.float32,
                              device=dev,
                              requires_grad=True)
        segments = torch.tensor([[0, 1, 95], [1, 20, 50]], dtype=torch.int32)
        dense_fsa_vec = k2.DenseFsaVec(log_prob, segments)
        fsa_vec = fsa_vec.to(dev)

        out_fsa = k2.intersect_dense_pruned(fsa_vec,
                                            dense_fsa_vec,
                                            search_beam=100,
                                            output_beam=100,
                                            min_active_states=1,
                                            max_active_states=10,
                                            seqframe_idx_name='seqframe',
                                            frame_idx_name='frame')

        expected = torch.arange(96).to(torch.int32).to(dev)
        assert torch.allclose(out_fsa.seqframe, expected)
        # fsa2 has no self-loop, so it cannot span the 50-frame second
        # segment: the second output FSA is empty and the frame indexes equal
        # the seqframe indexes of the first one.
        assert torch.allclose(out_fsa.frame, expected)
        assert out_fsa.shape == (2, None, None), 'There should be two FSAs!'

        tot = out_fsa.get_tot_scores(log_semiring=False,
                                     use_double_scores=False)
        tot.sum().backward()
def decode(dataloader: torch.utils.data.DataLoader,
           model: None,
           device: Union[str, torch.device],
           ctc_topo: None,
           numericalizer=None,
           num_paths=-1,
           output_beam_size: float = 8):
    """CTC-decode every batch of ``dataloader`` with ``ctc_topo`` (1-best).

    Args:
        dataloader: yields dicts with ``'inputs'`` ([N, T, C] features) and
            ``'supervisions'`` (``sequence_idx``, ``start_frame``,
            ``num_frames``, ``text``).
        model: called as ``model(feature, supervisions)``; returns a 3-tuple
            whose first item is the network output ([N, C, T]).
        device: device the features are moved to.
        ctc_topo: decoding graph intersected with the network output.
        numericalizer: provides ``tokens_list`` and ``tokenizer`` to turn
            token ids into words; required despite the ``None`` default.
        num_paths: accepted for interface compatibility; unused here.
        output_beam_size: ``output_beam`` of the pruned intersection.

    Returns:
        A list of ``(ref_words, hyp_words)`` pairs.
    """
    tot_num_cuts = len(dataloader.dataset.cuts)
    num_cuts = 0
    results = []
    for batch_idx, batch in enumerate(dataloader):
        assert isinstance(batch, dict), type(batch)
        feature = batch['inputs']
        supervisions = batch['supervisions']
        # Map frame indexes to the encoder output frame rate. The formula
        # presumably mirrors two stride-2 subsampling stages of the encoder
        # (TODO confirm against the model). Negative results (e.g. a
        # start_frame of 0) are clamped to 0.
        supervision_segments = torch.stack(
            (supervisions['sequence_idx'],
             (((supervisions['start_frame'] - 1) // 2 - 1) // 2),
             (((supervisions['num_frames'] - 1) // 2 - 1) // 2)),
            1).to(torch.int32)
        supervision_segments = torch.clamp(supervision_segments, min=0)
        indices = torch.argsort(supervision_segments[:, 2], descending=True)
        supervision_segments = supervision_segments[indices]
        texts = supervisions['text']
        assert feature.ndim == 3

        feature = feature.to(device)
        # at entry, feature is [N, T, C]
        feature = feature.permute(0, 2, 1)  # now feature is [N, C, T]

        # Decoding never backpropagates, so run the forward pass under
        # no_grad as well; otherwise an unused autograd graph is kept alive.
        with torch.no_grad():
            nnet_output, _, _ = model(feature, supervisions)
            nnet_output = nnet_output.permute(0, 2, 1)
            # TODO(Liyong Guo): Tune a blank bias on nnet_output[:, :, 0].
            dense_fsa_vec = k2.DenseFsaVec(nnet_output, supervision_segments)
            lattices = k2.intersect_dense_pruned(ctc_topo, dense_fsa_vec,
                                                 20.0, output_beam_size, 30,
                                                 10000)
            best_paths = k2.shortest_path(lattices, use_double_scores=True)

        hyps = get_texts(best_paths, indices)
        assert len(hyps) == len(texts)

        for i, text in enumerate(texts):
            pieces = [numericalizer.tokens_list[token_id] for token_id in hyps[i]]
            hyp_words = numericalizer.tokenizer.DecodePieces(pieces).split(' ')
            ref_words = text.split(' ')
            results.append((ref_words, hyp_words))

        if batch_idx % 10 == 0:
            logging.info(
                'batch {}, cuts processed until now is {}/{} ({:.6f}%)'.format(
                    batch_idx, num_cuts, tot_num_cuts,
                    float(num_cuts) / tot_num_cuts * 100))

        num_cuts += len(texts)

    return results
def get_objf(batch, model, device, L, symbols, training, optimizer=None):
    """Compute the negated total score of a batch and optionally train on it.

    Args:
        batch: dict with ``'features'`` ([N, T, C]) and ``'supervisions'``
            (``sequence_idx``, ``start_frame``, ``num_frames``, ``text``).
        model: acoustic model taking [N, C, T] input.
        device: device everything must live on (asserted to be CUDA).
        L, symbols: passed to ``create_decoding_graph`` with the texts.
        training: when True, run the forward pass with gradients and do one
            optimizer step (``optimizer`` must then be given).
        optimizer: optimizer used when ``training`` is True.

    Returns:
        ``(total_objf, total_frames)``: the negated sum of per-sequence total
        scores, and the number of supervised frames in the batch (sum of the
        ``num_frames`` column of the supervision segments).
    """
    feature = batch['features']
    supervisions = batch['supervisions']
    supervision_segments = torch.stack(
        (supervisions['sequence_idx'],
         supervisions['start_frame'],
         supervisions['num_frames']), 1).to(torch.int32)
    texts = supervisions['text']
    assert feature.ndim == 3

    # at entry, feature is [N, T, C]
    feature = feature.permute(0, 2, 1)  # now feature is [N, C, T]
    feature = feature.to(device)
    if training:
        nnet_output = model(feature)
    else:
        with torch.no_grad():
            nnet_output = model(feature)

    # nnet_output is [N, C, T]
    nnet_output = nnet_output.permute(0, 2, 1)  # now nnet_output is [N, T, C]

    # TODO(haowen): create decoding graph at the beginning of training
    decoding_graph = create_decoding_graph(texts, L, symbols)
    decoding_graph.to_(device)
    decoding_graph.scores.requires_grad_(False)
    dense_fsa_vec = k2.DenseFsaVec(nnet_output, supervision_segments)
    assert decoding_graph.is_cuda()
    assert decoding_graph.device == device
    assert nnet_output.device == device

    target_graph = k2.intersect_dense_pruned(decoding_graph, dense_fsa_vec,
                                             10, 10000, 0)
    tot_scores = -k2.get_tot_scores(target_graph, True, False).sum()
    if training:
        optimizer.zero_grad()
        tot_scores.backward()
        clip_grad_value_(model.parameters(), 5.0)
        optimizer.step()

    objf = tot_scores.detach().cpu()
    total_objf = objf.item()
    # nnet_output.shape[0] is the batch size N, not a frame count; count the
    # supervised frames so that objf / frames is a per-frame average.
    total_frames = int(supervision_segments[:, 2].sum().item())
    return total_objf, total_frames
def decode(dataloader: torch.utils.data.DataLoader, model: AcousticModel,
           device: Union[str, torch.device], LG: Fsa, symbols: SymbolTable):
    """Decode every batch of ``dataloader`` with the LG graph (1-best).

    Args:
        dataloader: yields dicts with ``'features'`` ([N, T, C]) and
            ``'supervisions'`` (``sequence_idx``, ``start_frame``,
            ``num_frames``, ``text``).
        model: acoustic model; ``model.subsampling_factor`` maps supervision
            frame indexes to output frames.
        device: device the features are moved to.
        LG: decoding graph; must be on CUDA and on the model output's device.
        symbols: symbol table mapping output word ids to words.

    Returns:
        A list of ``(ref_words, hyp_words)`` pairs.
    """
    results = []  # a list of pair (ref_words, hyp_words)
    for batch_idx, batch in enumerate(dataloader):
        feature = batch['features']
        supervisions = batch['supervisions']
        supervision_segments = torch.stack(
            (supervisions['sequence_idx'],
             torch.floor_divide(supervisions['start_frame'],
                                model.subsampling_factor),
             torch.floor_divide(supervisions['num_frames'],
                                model.subsampling_factor)), 1).to(torch.int32)
        texts = supervisions['text']
        assert feature.ndim == 3

        feature = feature.to(device)
        # at entry, feature is [N, T, C]
        feature = feature.permute(0, 2, 1)  # now feature is [N, C, T]
        with torch.no_grad():
            nnet_output = model(feature)
        # nnet_output is [N, C, T]
        nnet_output = nnet_output.permute(0, 2, 1)  # now nnet_output is [N, T, C]

        dense_fsa_vec = k2.DenseFsaVec(nnet_output, supervision_segments)

        assert LG.is_cuda()
        assert LG.device == nnet_output.device, \
            f"Check failed: LG.device ({LG.device}) == nnet_output.device ({nnet_output.device})"
        # TODO(haowen): with a small `beam`, we may get empty `target_graph`,
        # thus `tot_scores` will be `inf`. Definitely we need to handle this later.
        lattices = k2.intersect_dense_pruned(LG, dense_fsa_vec, 2000.0, 20.0, 30, 300)
        best_paths = k2.shortest_path(lattices, use_float_scores=True)
        best_paths = best_paths.to('cpu')
        assert best_paths.shape[0] == len(texts)

        for i, text in enumerate(texts):
            # Convert to plain Python ints first. Iterating a tensor yields
            # 0-d tensors, and k2's SymbolTable.get() only treats `int`
            # arguments as ids (anything else is looked up as a symbol).
            word_ids = best_paths[i].aux_labels.tolist()
            # Labels <= 0 (0 and the -1 on final arcs) are skipped.
            hyp_words = [symbols.get(x) for x in word_ids if x > 0]
            results.append((text.split(' '), hyp_words))

        if batch_idx % 10 == 0:
            logging.info('Processed batch {}/{} ({:.6f}%)'.format(
                batch_idx, len(dataloader),
                float(batch_idx) / len(dataloader) * 100))

    return results
def _compute_mmi_loss_pruned(
        nnet_output: torch.Tensor,
        texts: List[str],
        supervision_segments: torch.Tensor,
        graph_compiler: MmiTrainingGraphCompiler,
        P: k2.Fsa,
        den_scale: float = 1.0
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    '''
    MMI objective where the denominator lattice comes from pruned
    intersection (``k2.intersect_dense_pruned``).

    See :func:`_compute_mmi_loss_exact_optimized` for the meaning of the
    arguments.

    Note:
      This variant needs the least memory, but because of pruning the loss
      is not exact.
    '''
    num_graphs, den_graphs = graph_compiler.compile(texts, P,
                                                    replicate_den=False)
    dense_fsa_vec = k2.DenseFsaVec(nnet_output, supervision_segments)

    num_lats = k2.intersect_dense(num_graphs, dense_fsa_vec, output_beam=10.0)

    # The search_beam/output_beam/min_active_states/max_active_states values
    # have not been tuned; adjust them if needed.
    den_lats = k2.intersect_dense_pruned(den_graphs,
                                         dense_fsa_vec,
                                         search_beam=20.0,
                                         output_beam=7.0,
                                         min_active_states=30,
                                         max_active_states=10000)

    def _log_total(lats):
        return lats.get_tot_scores(log_semiring=True, use_double_scores=True)

    tot_scores = _log_total(num_lats) - den_scale * _log_total(den_lats)

    objf, supervised_frames, total_frames = get_tot_objf_and_num_frames(
        tot_scores, supervision_segments[:, 2])
    return objf, supervised_frames, total_frames
def decode(
        self,
        cuts: Union[AnyCut, CutSet]) -> List[Tuple[List[str], List[str]]]:
    """
    Perform decoding with an n-gram language model (HLG graph).
    Doesn't support rescoring at this time.

    Args:
        cuts: a single ``Cut``/``MixedCut`` (wrapped into a ``CutSet``) or a
            ``CutSet``; features are extracted on the fly with
            ``self.extractor``.

    Returns:
        A list of ``(ref_words, hyp_words)`` pairs.
    """
    if isinstance(cuts, (Cut, MixedCut)):
        cuts = CutSet.from_cuts([cuts])
    word_results = []
    # Hacky way to get batch quickly... we may need to improve on this.
    batch = K2SpeechRecognitionDataset(
        cuts,
        input_strategy=OnTheFlyFeatures(self.extractor),
        check_inputs=False)[list(cuts.ids)]
    features = batch['inputs'].permute(0, 2, 1).to(
        self.device)  # (B, T, F) -> (B, F, T)
    supervision_segments, texts = encode_supervisions(batch['supervisions'])

    # Forward pass through the acoustic model. Nothing is backpropagated
    # here, so skip building an autograd graph.
    with torch.no_grad():
        posteriors, _, _ = self.model(features)
    posteriors = posteriors.permute(0, 2, 1)  # (B, F, T) -> (B, T, F)

    # Wrapping into k2 "dense FSA" (representing PPG as a dense graph)
    dense_fsa_vec = k2.DenseFsaVec(posteriors, supervision_segments)

    # The actual decoding starts here:
    # First, we intersect the HLG and the PPG
    # with default pruning/beam search params from snowfall.
    # The result is a batch of graphs (lattices).
    lattices = k2.intersect_dense_pruned(self.HLG, dense_fsa_vec, 20.0, 8,
                                         30, 10000)

    # ... then we find the shortest paths in the lattices ...
    best_paths = k2.shortest_path(lattices, use_double_scores=True)

    # ... and convert them to words with a convenience wrapper from snowfall.
    # Identity indices are used, which assumes encode_supervisions() returns
    # `texts` in the same order as `supervision_segments` (TODO confirm).
    hyps = get_texts(best_paths, torch.arange(len(texts)))

    # Here we read out the words from the best path graphs
    for i, text in enumerate(texts):
        hyp_words = [self.lexicon.words.get(x) for x in hyps[i]]
        ref_words = text.split(' ')
        word_results.append((ref_words, hyp_words))
    return word_results
def test_simple(self):
    """One FSA and one 2-frame dense segment: check scores and gradients.

    Expected values were computed using gtn; see https://bit.ly/3oYObeb.
    """
    fsa = k2.Fsa.from_str('''
        0 1 1 1.0
        1 1 1 50.0
        1 2 2 2.0
        2 3 -1 3.0
        3
    ''')
    fsa.requires_grad_(True)
    fsa_vec = k2.create_fsa_vec([fsa])

    log_prob = torch.tensor([[[0.1, 0.2, 0.3], [0.04, 0.05, 0.06]]],
                            dtype=torch.float32,
                            requires_grad=True)
    segments = torch.tensor([[0, 0, 2]], dtype=torch.int32)
    dense_fsa_vec = k2.DenseFsaVec(log_prob, segments)

    out_fsa = k2.intersect_dense_pruned(fsa_vec,
                                        dense_fsa_vec,
                                        search_beam=100000,
                                        output_beam=100000,
                                        min_active_states=0,
                                        max_active_states=10000)
    tot = k2.get_tot_scores(out_fsa,
                            log_semiring=False,
                            use_float_scores=True)
    tot.sum().backward()

    want_arc_scores = torch.tensor([1.2, 2.06, 3.0])
    want_fsa_grad = torch.tensor([1.0, 0.0, 1.0, 1.0])
    want_log_prob_grad = torch.tensor([0.0, 1.0, 0.0, 0.0, 0.0,
                                       1.0]).reshape_as(log_prob)

    assert torch.allclose(out_fsa.scores, want_arc_scores)
    assert torch.allclose(want_fsa_grad, fsa.scores.grad)
    assert torch.allclose(want_log_prob_grad, log_prob.grad)
def __call__(
    self, batch: Dict[str, Union[torch.Tensor, np.ndarray]]
) -> List[Tuple[Optional[str], List[str], List[int], float]]:
    """Decode a batch with k2, optionally with n-best rescoring.

    Args:
        batch: dict with ``"speech"`` and ``"speech_lengths"`` entries
            (numpy arrays are converted to tensors); it is moved to
            ``self.device`` and passed as keyword arguments to
            ``self.asr_model.encode``.

    Returns:
        A list of ``(text, token, token_int, score)`` tuples, one per chosen
        hypothesis: detokenized text, token strings, token ids and score.
        With n-best rescoring, the score is the weighted sum of am, decoder
        and nnlm scores. Otherwise it is the best path's total score.
    """
    assert check_argument_types()

    if isinstance(batch["speech"], np.ndarray):
        batch["speech"] = torch.tensor(batch["speech"])
    if isinstance(batch["speech_lengths"], np.ndarray):
        batch["speech_lengths"] = torch.tensor(batch["speech_lengths"])

    # a. To device
    batch = to_device(batch, device=self.device)

    # b. Forward Encoder
    # enc: [N, T, C]
    enc, encoder_out_lens = self.asr_model.encode(**batch)

    # logp_encoder_output: [N, T, C]
    logp_encoder_output = torch.nn.functional.log_softmax(
        self.asr_model.ctc.ctc_lo(enc), dim=2
    )

    # It may be useful to tune blank_bias; it is added to the log-prob of
    # token 0 (assumed to be the CTC blank).
    # The valid range of blank_bias is [-inf, 0]
    logp_encoder_output[:, :, 0] += self.blank_bias

    # Build one supervision segment per utterance:
    # [sequence_idx, start_frame=0, num_frames=encoder_out_lens[i]].
    batch_size = encoder_out_lens.size(0)
    sequence_idx = torch.arange(0, batch_size).unsqueeze(0).t().to(torch.int32)
    start_frame = torch.zeros([batch_size], dtype=torch.int32).unsqueeze(0).t()
    num_frames = encoder_out_lens.cpu().unsqueeze(0).t().to(torch.int32)
    supervision_segments = torch.cat([sequence_idx, start_frame, num_frames], dim=1)
    supervision_segments = supervision_segments.to(torch.int32)

    # An introduction to DenseFsaVec:
    # https://k2-fsa.github.io/k2/core_concepts/index.html#dense-fsa-vector
    # It can be viewed as an FSA-typed logp_encoder_output,
    # whose arc weights are initialized with logp_encoder_output.
    # Converting the tensor to an FSA lets us use the
    # FSA-related functions in k2, e.g. k2.intersect_dense_pruned below.
    dense_fsa_vec = k2.DenseFsaVec(logp_encoder_output, supervision_segments)

    # The term "intersect" is similar to "compose" in k2.
    # The differences are:
    # for "compose" functions, the composition involves
    # matching output label of a.fsa and input label of b.fsa
    # while for "intersect" functions, the composition involves
    # matching input label of a.fsa and input label of b.fsa
    # Actually, in compose functions, b.fsa is inverted and then
    # a.fsa and inv_b.fsa are intersected together.
    # For difference between compose and intersect:
    # https://github.com/k2-fsa/k2/blob/master/k2/python/k2/fsa_algo.py#L308
    # For definition of k2.intersect_dense_pruned:
    # https://github.com/k2-fsa/k2/blob/master/k2/python/k2/autograd.py#L648
    lattices = k2.intersect_dense_pruned(
        self.decode_graph,
        dense_fsa_vec,
        self.search_beam_size,
        self.output_beam_size,
        self.min_active_states,
        self.max_active_states,
    )

    # lattices.scores is the sum of decode_graph.scores (a.k.a. lm weight) and
    # dense_fsa_vec.scores (a.k.a. am weight) on related arcs.
    # For the ctc decoding graph, lattices.scores only store am weight
    # since the decode_graph only defines the ctc topology and
    # has no lm weight on its arcs.
    # While for 3-gram decoding, whose graph is converted from language models,
    # lattice.scores contains both am weights and lm weights
    #
    # It may be useful to tune lattice.scores
    # The valid range of lattice_weight is [0, inf)
    # The lattice_weight will affect the search of k2.random_paths
    lattices.scores *= self.lattice_weight

    results = []
    if self.use_nbest_rescoring:
        # NOTE(review): lm_scores, new2old (and x_lengths below) are unpacked
        # but not used in this method.
        (
            am_scores,
            lm_scores,
            token_ids,
            new2old,
            path_to_seq_map,
            seq_to_path_splits,
        ) = nbest_am_lm_scores(
            lattices, self.num_paths, self.device, self.nbest_batch_size
        )

        # Right-pad every path's token ids with ignore_id to the longest path.
        ys_pad_lens = torch.tensor([len(hyp) for hyp in token_ids]).to(self.device)
        max_token_length = max(ys_pad_lens)
        ys_pad_list = []
        for hyp in token_ids:
            ys_pad_list.append(
                torch.cat(
                    [
                        torch.tensor(hyp, dtype=torch.long),
                        torch.tensor(
                            [self.asr_model.ignore_id]
                            * (max_token_length.item() - len(hyp)),
                            dtype=torch.long,
                        ),
                    ]
                )
            )

        ys_pad = (
            torch.stack(ys_pad_list).to(torch.long).to(self.device)
        )  # [batch, max_token_length]

        # Replicate encoder outputs/lengths once per path, using
        # path_to_seq_map to pick the source utterance of each path.
        encoder_out = enc.index_select(0, path_to_seq_map.to(torch.long)).to(
            self.device
        )  # [batch, T, dim]
        encoder_out_lens = encoder_out_lens.index_select(
            0, path_to_seq_map.to(torch.long)
        ).to(
            self.device
        )  # [batch]

        decoder_scores = -self.asr_model.batchify_nll(
            encoder_out, encoder_out_lens, ys_pad, ys_pad_lens, self.nll_batch_size
        )

        # padded_value for nnlm is 0
        ys_pad[ys_pad == self.asr_model.ignore_id] = 0
        nnlm_nll, x_lengths = self.lm.batchify_nll(
            ys_pad, ys_pad_lens, self.nll_batch_size
        )
        nnlm_scores = -nnlm_nll.sum(dim=1)

        batch_tot_scores = (
            self.am_weight * am_scores
            + self.decoder_weight * decoder_scores
            + self.nnlm_weight * nnlm_scores
        )
        # Split the flat per-path scores into one chunk per utterance.
        split_size = indices_to_split_size(
            seq_to_path_splits.tolist(), total_elements=batch_tot_scores.size(0)
        )
        batch_tot_scores = torch.split(
            batch_tot_scores,
            split_size,
        )

        # For each utterance keep the highest-scoring path; processed_seqs
        # is the offset of the current chunk within token_ids.
        hyps = []
        scores = []
        processed_seqs = 0
        for tot_scores in batch_tot_scores:
            if tot_scores.nelement() == 0:
                # the last element by torch.tensor_split may be empty
                # e.g.
                # torch.tensor_split(torch.tensor([1,2,3,4]), torch.tensor([2,4]))
                # (tensor([1, 2]), tensor([3, 4]), tensor([], dtype=torch.int64))
                # NOTE(review): the split above uses torch.split, not
                # torch.tensor_split; keep this guard in sync with it.
                break
            best_seq_idx = processed_seqs + torch.argmax(tot_scores)

            assert best_seq_idx < len(token_ids)
            best_token_seqs = token_ids[best_seq_idx]
            processed_seqs += tot_scores.nelement()
            hyps.append(best_token_seqs)
            scores.append(tot_scores.max().item())

        assert len(hyps) == len(split_size)
    else:
        best_paths = k2.shortest_path(lattices, use_double_scores=True)
        scores = best_paths.get_tot_scores(
            use_double_scores=True, log_semiring=False
        ).tolist()
        hyps = get_texts(best_paths)

    assert len(scores) == len(hyps)

    for token_int, score in zip(hyps, scores):
        # For decoding methods nbest_rescoring and ctc_decoding,
        # hyps stores token indexes, which are lattice.labels.
        # Convert token ids to text with self.tokenizer.
        token = self.converter.ids2tokens(token_int)
        assert self.tokenizer is not None
        text = self.tokenizer.tokens2text(token)
        results.append((text, token, token_int, score))

    assert check_return_type(results)
    return results
def decode(
    dataloader: torch.utils.data.DataLoader,
    model: AcousticModel,
    device: Union[str, torch.device],
    HLG: Fsa,
    symbols: SymbolTable,
):
    """Run 1-best HLG decoding over ``dataloader``.

    A fixed blank bias of -3.0 is added to the score of output column 0
    before building the dense FSA.

    Args:
        dataloader: yields dicts with ``"inputs"`` ([N, T, C]) and
            ``"supervisions"``; it may lack ``len()``, in which case progress
            logs omit the total number of batches.
        model: acoustic model; ``model.subsampling_factor`` maps supervision
            frame indexes to output frames.
        device: device the features are moved to.
        HLG: decoding graph; must share the model output's device.
        symbols: symbol table mapping hypothesis ids to words.

    Returns:
        A list of ``(ref_words, hyp_words)`` pairs.
    """
    try:
        num_batches = len(dataloader)
    except TypeError:
        num_batches = None

    num_cuts = 0
    results = []  # (ref_words, hyp_words) pairs

    for batch_idx, batch in enumerate(dataloader):
        feature = batch["inputs"]
        sups = batch["supervisions"]
        segments = torch.stack(
            (
                sups["sequence_idx"],
                torch.floor_divide(sups["start_frame"], model.subsampling_factor),
                torch.floor_divide(sups["num_frames"], model.subsampling_factor),
            ),
            1,
        ).to(torch.int32)
        # Longest segments first; get_texts() receives `order` below.
        order = torch.argsort(segments[:, 2], descending=True)
        segments = segments[order]
        texts = sups["text"]
        assert feature.ndim == 3

        # [N, T, C] -> [N, C, T] for the acoustic model.
        feature = feature.to(device).permute(0, 2, 1)
        with torch.no_grad():
            nnet_output = model(feature)
        # [N, C, T] -> [N, T, C]
        nnet_output = nnet_output.permute(0, 2, 1)

        blank_bias = -3.0
        nnet_output[:, :, 0] += blank_bias

        dense_fsa_vec = k2.DenseFsaVec(nnet_output, segments)
        assert (
            HLG.device == nnet_output.device
        ), f"Check failed: HLG.device ({HLG.device}) == nnet_output.device ({nnet_output.device})"
        # TODO(haowen): with a small `beam`, we may get empty `target_graph`,
        # thus `tot_scores` will be `inf`. Definitely we need to handle this later.
        lattices = k2.intersect_dense_pruned(HLG, dense_fsa_vec, 20.0, 7.0, 30, 10000)
        best_paths = k2.shortest_path(lattices, use_double_scores=True)
        assert best_paths.shape[0] == len(texts)
        hyps = get_texts(best_paths, order)
        assert len(hyps) == len(texts)

        results.extend(
            (text.split(" "), [symbols.get(x) for x in hyp])
            for text, hyp in zip(texts, hyps)
        )

        if batch_idx % 10 == 0:
            progress = (
                f"{batch_idx}" if num_batches is None else f"{batch_idx}/{num_batches}"
            )
            logging.info(
                f"batch {progress}, number of cuts processed until now is {num_cuts}"
            )

        num_cuts += len(texts)

    return results
def get_loss(batch: Dict,
             model: AcousticModel,
             P: k2.Fsa,
             device: torch.device,
             graph_compiler: MmiMbrTrainingGraphCompiler,
             is_training: bool,
             optimizer: Optional[torch.optim.Optimizer] = None):
    """Compute the combined MMI + MBR loss for a batch; optionally train.

    Args:
        batch: dict with ``'features'`` ([N, T, C]) and ``'supervisions'``
            (``sequence_idx``, ``start_frame``, ``num_frames``, ``text``).
        model: acoustic model; ``model.subsampling_factor`` maps supervision
            frame indexes to output frames.
        P: FSA passed to ``graph_compiler.compile``; must be on ``device``.
        device: device all graphs and the network output must live on.
        graph_compiler: returns ``(num_graph, den_graph, decoding_graph)``.
        is_training: when True, gradients are tracked and one optimizer step
            (with gradient clipping at 5.0) is taken.
        optimizer: required when ``is_training`` is True.

    Returns:
        Tuple of Python numbers ``(mmi_loss, mbr_loss, tot_frames,
        all_frames)``.
    """
    assert P.device == device
    feature = batch['features']
    supervisions = batch['supervisions']
    supervision_segments = torch.stack(
        (supervisions['sequence_idx'],
         torch.floor_divide(supervisions['start_frame'],
                            model.subsampling_factor),
         torch.floor_divide(supervisions['num_frames'],
                            model.subsampling_factor)), 1).to(torch.int32)
    # Sort segments by decreasing length and reorder texts to match.
    indices = torch.argsort(supervision_segments[:, 2], descending=True)
    supervision_segments = supervision_segments[indices]

    texts = supervisions['text']
    texts = [texts[idx] for idx in indices]
    assert feature.ndim == 3
    # print(supervision_segments[:, 1] + supervision_segments[:, 2])

    feature = feature.to(device)
    # at entry, feature is [N, T, C]
    feature = feature.permute(0, 2, 1)  # now feature is [N, C, T]
    if is_training:
        nnet_output = model(feature)
    else:
        with torch.no_grad():
            nnet_output = model(feature)

    # nnet_output is [N, C, T]
    nnet_output = nnet_output.permute(0, 2, 1)  # now nnet_output is [N, T, C]

    if is_training:
        num_graph, den_graph, decoding_graph = graph_compiler.compile(texts, P)
    else:
        with torch.no_grad():
            num_graph, den_graph, decoding_graph = graph_compiler.compile(
                texts, P)

    assert num_graph.requires_grad == is_training
    assert den_graph.requires_grad is False
    assert decoding_graph.requires_grad is False
    # The decoding graph is a single FSA (or an FsaVec holding one FSA).
    assert len(
        decoding_graph.shape) == 2 or decoding_graph.shape == (1, None, None)

    num_graph = num_graph.to(device)
    den_graph = den_graph.to(device)

    decoding_graph = decoding_graph.to(device)

    dense_fsa_vec = k2.DenseFsaVec(nnet_output, supervision_segments)

    assert nnet_output.device == device

    num_lats = k2.intersect_dense(num_graph,
                                  dense_fsa_vec,
                                  10.0,
                                  seqframe_idx_name='seqframe_idx')

    mbr_lats = k2.intersect_dense_pruned(decoding_graph,
                                         dense_fsa_vec,
                                         20.0,
                                         7.0,
                                         30,
                                         10000,
                                         seqframe_idx_name='seqframe_idx')

    if True:
        # WARNING: the else branch is not working at present (the total loss is not stable)
        # Because this branch is always taken, den_lats is never mbr_lats and
        # the finite-score filtering below is skipped (finite_indexes=None).
        den_lats = k2.intersect_dense(den_graph, dense_fsa_vec, 10.0)
    else:
        # in this case, we can remove den_graph
        den_lats = mbr_lats

    num_tot_scores = num_lats.get_tot_scores(
        log_semiring=True,
        use_double_scores=True)
    den_tot_scores = den_lats.get_tot_scores(
        log_semiring=True,
        use_double_scores=True)

    if id(den_lats) == id(mbr_lats):
        # Some entries in den_tot_scores may be -inf.
        # The corresponding sequences are discarded/ignored.
        finite_indexes = torch.isfinite(den_tot_scores)
        den_tot_scores = den_tot_scores[finite_indexes]
        num_tot_scores = num_tot_scores[finite_indexes]
    else:
        finite_indexes = None

    # NOTE(review): `den_scale` is neither a parameter nor a local here; it
    # must be a module-level global defined elsewhere in this file, otherwise
    # this line raises NameError -- confirm.
    tot_scores = num_tot_scores - den_scale * den_tot_scores

    (tot_score, tot_frames,
     all_frames) = get_tot_objf_and_num_frames(tot_scores,
                                               supervision_segments[:, 2],
                                               finite_indexes)

    # Sparse [seqframe, phone] matrices of arc posteriors for the numerator
    # and the decoding lattices. The "- 1" presumably drops the extra score
    # column that DenseFsaVec keeps for the final symbol -- TODO confirm.
    num_rows = dense_fsa_vec.scores.shape[0]
    num_cols = dense_fsa_vec.scores.shape[1] - 1
    mbr_num_sparse = k2.create_sparse(rows=num_lats.seqframe_idx,
                                      cols=num_lats.phones,
                                      values=num_lats.get_arc_post(True,
                                                                   True).exp(),
                                      size=(num_rows, num_cols),
                                      min_col_index=0)

    mbr_den_sparse = k2.create_sparse(rows=mbr_lats.seqframe_idx,
                                      cols=mbr_lats.phones,
                                      values=mbr_lats.get_arc_post(True,
                                                                   True).exp(),
                                      size=(num_rows, num_cols),
                                      min_col_index=0)
    # NOTE: Due to limited support of PyTorch's autograd for sparse tensors,
    # we cannot use (mbr_num_sparse - mbr_den_sparse) here
    #
    # The following works only for torch >= 1.7.0
    mbr_loss = torch.sparse.sum(
        k2.sparse.abs((mbr_num_sparse + (-mbr_den_sparse)).coalesce()))

    mmi_loss = -tot_score

    total_loss = mmi_loss + mbr_loss

    if is_training:
        optimizer.zero_grad()
        total_loss.backward()
        clip_grad_value_(model.parameters(), 5.0)
        optimizer.step()

    ans = (
        mmi_loss.detach().cpu().item(),
        mbr_loss.detach().cpu().item(),
        tot_frames.cpu().item(),
        all_frames.cpu().item(),
    )
    return ans
def __call__(
    self, speech: Union[torch.Tensor, np.ndarray]
) -> List[Tuple[Optional[str], List[str], List[int], float]]:
    """Decode one utterance with CTC + k2 (batch size 1 only).

    Args:
        speech: 1-D audio samples; numpy input is converted to a tensor.

    Returns:
        A one-element list holding ``(text, token, token_int, score)``;
        ``text`` is None when no tokenizer is configured.
    """
    assert check_argument_types()

    if isinstance(speech, np.ndarray):
        speech = torch.tensor(speech)

    # (Nsamples,) -> (1, Nsamples), cast to the configured dtype.
    speech = speech.unsqueeze(0).to(getattr(torch, self.dtype))
    lengths = speech.new_full([1], dtype=torch.long, fill_value=speech.size(1))
    batch = to_device(
        {"speech": speech, "speech_lengths": lengths}, device=self.device
    )

    enc, _ = self.asr_model.encode(**batch)  # enc: [N, T, C]
    assert len(enc) == 1, len(enc)

    # CTC log-probabilities: [N, T, C]
    logp_encoder_output = torch.nn.functional.log_softmax(
        self.asr_model.ctc.ctc_lo(enc), dim=2)

    # TODO(Liyong Guo): Support batch decoding. Only batch_size == 1 is
    # handled: one segment spanning every encoder frame.
    supervision_segments = torch.tensor([[0, 0, enc.shape[1]]],
                                        dtype=torch.int32)
    indices = torch.tensor([0])

    dense_fsa_vec = k2.DenseFsaVec(logp_encoder_output, supervision_segments)
    lattices = k2.intersect_dense_pruned(self.decode_graph, dense_fsa_vec,
                                         20.0, self.output_beam_size, 30,
                                         10000)
    best_paths = k2.shortest_path(lattices, use_double_scores=True)
    scores = best_paths.get_tot_scores(use_double_scores=True,
                                       log_semiring=False).tolist()
    hyps = get_texts(best_paths, indices)

    assert len(scores) == 1
    assert len(scores) == len(hyps)

    results = []
    for token_int, score in zip(hyps, scores):
        token = self.converter.ids2tokens(token_int)
        text = (None if self.tokenizer is None
                else self.tokenizer.tokens2text(token))
        results.append((text, token, token_int, score))

    assert check_return_type(results)
    return results
def test_two_fsas(self):
    """Two FSAs intersected with two dense segments (pruned intersection).

    Checks the output arc scores and the gradients w.r.t. both FSAs' scores
    and the log-probs. Expected values were computed using gtn; see
    https://bit.ly/3oYObeb.
    """
    s1 = '''
        0 1 1 1.0
        1 2 2 2.0
        2 3 -1 3.0
        3
    '''

    s2 = '''
        0 1 1 1.0
        1 1 1 50.0
        1 2 2 2.0
        2 3 -1 3.0
        3
    '''

    fsa1 = k2.Fsa.from_str(s1)
    fsa2 = k2.Fsa.from_str(s2)
    fsa1.requires_grad_(True)
    fsa2.requires_grad_(True)
    fsa_vec = k2.create_fsa_vec([fsa1, fsa2])

    log_prob = torch.tensor(
        [[[0.1, 0.2, 0.3], [0.04, 0.05, 0.06], [0.0, 0.0, 0.0]],
         [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.0, 0.0, 0.0]]],
        dtype=torch.float32,
        requires_grad=True)
    supervision_segments = torch.tensor([[0, 0, 2], [1, 0, 3]],
                                        dtype=torch.int32)
    dense_fsa_vec = k2.DenseFsaVec(log_prob, supervision_segments)

    out_fsa = k2.intersect_dense_pruned(fsa_vec,
                                        dense_fsa_vec,
                                        search_beam=100000,
                                        output_beam=100000,
                                        min_active_states=0,
                                        max_active_states=10000)
    assert out_fsa.shape == (2, None, None), 'There should be two FSAs!'

    scores = k2.get_tot_scores(out_fsa,
                               log_semiring=False,
                               use_float_scores=True)
    scores.sum().backward()

    expected_scores_out_fsa = torch.tensor(
        [1.2, 2.06, 3.0, 1.2, 50.5, 2.0, 3.0])
    expected_grad_fsa1 = torch.tensor([1.0, 1.0, 1.0])
    expected_grad_fsa2 = torch.tensor([1.0, 1.0, 1.0, 1.0])
    expected_grad_log_prob = torch.tensor([
        0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0, 0, 0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0,
        0.0, 0.0, 1.0
    ]).reshape_as(log_prob)

    assert torch.allclose(out_fsa.scores, expected_scores_out_fsa)
    assert torch.allclose(expected_grad_fsa1, fsa1.scores.grad)
    assert torch.allclose(expected_grad_fsa2, fsa2.scores.grad)
    assert torch.allclose(expected_grad_log_prob, log_prob.grad)
def decode(
    self,
    log_probs: torch.Tensor,
    log_probs_length: torch.Tensor,
    return_lattices: bool = False,
    return_ilabels: bool = False,
    output_aligned: bool = True,
) -> Union['k2.Fsa', Tuple[List[torch.Tensor], List[torch.Tensor]]]:
    """Decode a batch of log-probabilities against the decoding graph.

    Args:
        log_probs: per-frame log-probabilities; presumably
            (batch, time, num_classes) — TODO confirm against callers.
        log_probs_length: number of valid frames for each sequence.
        return_lattices: if True, return the lattices (put back into input
            order) instead of best paths.
        return_ilabels: if True, best-path labels come from the input labels
            (``labels``); otherwise from ``aux_labels``.
        output_aligned: if False and ``return_ilabels`` is False, blank
            labels are removed from each best path.

    Returns:
        Either a ``k2.Fsa`` of lattices, or a tuple ``(shortest_paths, probs)``.
        Here ``shortest_paths[i]`` holds the best-path labels of sequence ``i``
        and ``probs[i]`` holds the exponentiated arc weights of that path
        (final arc excluded).
    """
    if self.decoding_graph is None:
        self.decoding_graph = self.base_graph

    if self.blank != 0:
        # rearrange log_probs to put blank at the first place
        # and shift targets to emulate blank = 0
        log_probs, _ = make_blank_first(self.blank, log_probs, None)
    supervisions, order = create_supervision(log_probs_length)
    if self.decoding_graph.shape[0] > 1:
        # Per-sequence graphs must follow the reordering applied to the supervisions.
        self.decoding_graph = k2.index_fsa(self.decoding_graph, order).to(device=log_probs.device)

    if log_probs.device != self.device:
        self.to(log_probs.device)
    dense_fsa_vec = (
        prep_padded_densefsavec(log_probs, supervisions)
        if self.pad_fsavec
        else k2.DenseFsaVec(log_probs, supervisions)
    )

    if self.intersect_pruned:
        lats = k2.intersect_dense_pruned(
            a_fsas=self.decoding_graph,
            b_fsas=dense_fsa_vec,
            search_beam=self.intersect_conf.search_beam,
            output_beam=self.intersect_conf.output_beam,
            min_active_states=self.intersect_conf.min_active_states,
            max_active_states=self.intersect_conf.max_active_states,
        )
    else:
        # intersect_dense needs one graph per sequence; replicate a single graph.
        indices = torch.zeros(dense_fsa_vec.dim0(), dtype=torch.int32, device=self.device)
        dec_graphs = (
            k2.index_fsa(self.decoding_graph, indices)
            if self.decoding_graph.shape[0] == 1
            else self.decoding_graph
        )
        lats = k2.intersect_dense(dec_graphs, dense_fsa_vec, self.intersect_conf.output_beam)
    if self.pad_fsavec:
        shift_labels_inpl([lats], -1)
    # The graph was possibly reordered for this batch; rebuild it on the next call.
    self.decoding_graph = None

    if return_lattices:
        lats = k2.index_fsa(lats, invert_permutation(order).to(device=log_probs.device))
        if self.blank != 0:
            # change only ilabels
            # suppose self.blank == self.num_classes - 1
            # Undo make_blank_first: 0 -> blank, k > 0 -> k - 1. Negative
            # labels (-1 on arcs entering the final state) must stay as-is;
            # shifting them would produce -2 and break the final-arc marking.
            labels = lats.labels
            lats.labels = torch.where(
                labels == 0,
                self.blank,
                torch.where(labels > 0, labels - 1, labels),
            )
        return lats
    else:
        shortest_path_fsas = k2.index_fsa(
            k2.shortest_path(lats, True),
            invert_permutation(order).to(device=log_probs.device),
        )
        shortest_paths = []
        probs = []
        # direct iterating does not work as expected
        for i in range(shortest_path_fsas.shape[0]):
            shortest_path_fsa = shortest_path_fsas[i]
            # [:-1] drops the final arc (label -1).
            labels = (
                shortest_path_fsa.labels[:-1].to(dtype=torch.long)
                if return_ilabels
                else shortest_path_fsa.aux_labels[:-1].to(dtype=torch.long)
            )
            if self.blank != 0:
                # suppose self.blank == self.num_classes - 1
                labels = torch.where(labels == 0, self.blank, labels - 1)
            if not return_ilabels and not output_aligned:
                labels = labels[labels != self.blank]
            shortest_paths.append(labels[::2] if self.pad_fsavec else labels)
            probs.append(get_arc_weights(shortest_path_fsa)[:-1].to(device=log_probs.device).exp())
        return shortest_paths, probs
def decode_one_batch(batch: Dict[str, Any],
                     model: AcousticModel,
                     HLG: k2.Fsa,
                     output_beam_size: float,
                     num_paths: int,
                     use_whole_lattice: bool,
                     G: Optional[k2.Fsa] = None) -> Dict[str, List[List[int]]]:
    '''Decode one batch and return the results keyed by decoding setting.

    The returned dict maps a setting name to that setting's results:

      - Without rescoring (``G is None``) the key is ``no_rescore``, or
        ``no_rescore-<num_paths>`` when n-best decoding is used
        (``num_paths > 1``).
      - With LM rescoring there is one key per LM scale, of the form
        ``lm_scale_xxx`` (e.g. ``lm_scale_0.7``).

    Each value has one entry per utterance in the batch; ``value[i]`` is the
    decoding result for the i-th utterance.

    Args:
      batch:
        A value obtained by iterating
        `lhotse.dataset.K2SpeechRecognitionDataset`; see its documentation
        for the format.
      model:
        The neural network model.
      HLG:
        The decoding graph.
      output_beam_size:
        Output beam passed to the pruned intersection.
      num_paths:
        Size of ``n`` for n-best list decoding / rescoring.
      use_whole_lattice:
        If True, ``G`` must not be None and whole-lattice LM rescoring is
        used. If False and ``G`` is not None, ``num_paths`` must be positive
        and n-best list rescoring is used.
      G:
        The LM. If None, no rescoring is performed.

    Returns:
      The dict described above.
    '''
    device = HLG.device

    feature = batch['inputs']
    assert feature.ndim == 3
    # [N, T, C] -> [N, C, T], the layout the model consumes.
    feature = feature.to(device).permute(0, 2, 1)

    supervisions = batch['supervisions']
    nnet_output, _, _ = model(feature, supervisions)
    # Model output is [N, C, T]; bring it to [N, T, C] for DenseFsaVec.
    nnet_output = nnet_output.permute(0, 2, 1)

    def _subsampled(frames):
        # Frame index/count after the model's subsampling
        # (presumably two stride-2 stages — TODO confirm against the model).
        return ((frames - 1) // 2 - 1) // 2

    supervision_segments = torch.stack(
        (supervisions['sequence_idx'],
         _subsampled(supervisions['start_frame']),
         _subsampled(supervisions['num_frames'])),
        dim=1,
    ).to(torch.int32).clamp(min=0)

    # Longest segment first; `indices` is handed to get_texts alongside the paths.
    indices = torch.argsort(supervision_segments[:, 2], descending=True)
    supervision_segments = supervision_segments[indices]

    dense_fsa_vec = k2.DenseFsaVec(nnet_output, supervision_segments)
    lattices = k2.intersect_dense_pruned(HLG, dense_fsa_vec, 20.0,
                                         output_beam_size, 30, 10000)

    if G is None:
        if num_paths > 1:
            key = f'no_rescore-{num_paths}'
            best_paths = nbest_decoding(lattices, num_paths)
        else:
            key = 'no_rescore'
            best_paths = k2.shortest_path(lattices, use_double_scores=True)
        return {key: get_texts(best_paths, indices)}

    lm_scale_list = [
        0.8, 0.9, 1.0, 1.1, 1.2, 1.3, 1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2.0
    ]
    if use_whole_lattice:
        best_paths_dict = rescore_with_whole_lattice(lattices, G,
                                                     lm_scale_list)
    else:
        best_paths_dict = rescore_with_n_best_list(lattices, G, num_paths,
                                                   lm_scale_list)

    # best_paths_dict maps 'lm_scale_xxx' to the best paths for that scale.
    return {
        lm_scale_str: get_texts(best_paths, indices)
        for lm_scale_str, best_paths in best_paths_dict.items()
    }
def __call__(
    self, batch: Dict[str, Union[torch.Tensor, np.ndarray]]
) -> List[Tuple[Optional[str], List[str], List[int], float]]:
    """Decode a batch of utterances with the k2 decoding graph.

    Args:
        batch: Input speech data and corresponding lengths, under the keys
            ``"speech"`` and ``"speech_lengths"``. NumPy arrays are converted
            to tensors. The caller's dict is left unmodified.

    Returns:
        One ``(text, token, token_int, score)`` tuple per utterance.
        ``text`` is None when no tokenizer is configured, and ``score`` is the
        total (tropical-semiring) score of the best path.
    """
    assert check_argument_types()

    # Shallow copy so the NumPy -> tensor conversion below does not mutate
    # the dict the caller passed in.
    batch = dict(batch)
    if isinstance(batch["speech"], np.ndarray):
        batch["speech"] = torch.tensor(batch["speech"])
    if isinstance(batch["speech_lengths"], np.ndarray):
        batch["speech_lengths"] = torch.tensor(batch["speech_lengths"])

    # a. To device
    batch = to_device(batch, device=self.device)

    # b. Forward Encoder
    # enc: [N, T, C]
    enc, encoder_out_lens = self.asr_model.encode(**batch)
    # logp_encoder_output: [N, T, C]
    logp_encoder_output = torch.nn.functional.log_softmax(
        self.asr_model.ctc.ctc_lo(enc), dim=2)

    # One supervision segment per utterance: (sequence_idx, start_frame=0,
    # num_frames), as an int32 CPU tensor of shape [batch_size, 3].
    batch_size = encoder_out_lens.size(0)
    supervision_segments = torch.stack(
        (
            torch.arange(batch_size, dtype=torch.int32),
            torch.zeros(batch_size, dtype=torch.int32),
            encoder_out_lens.cpu().to(torch.int32),
        ),
        dim=1,
    )

    dense_fsa_vec = k2.DenseFsaVec(logp_encoder_output, supervision_segments)
    lattices = k2.intersect_dense_pruned(self.decode_graph, dense_fsa_vec,
                                         20.0, self.output_beam_size, 30,
                                         10000)
    best_paths = k2.shortest_path(lattices, use_double_scores=True)
    scores = best_paths.get_tot_scores(use_double_scores=True,
                                       log_semiring=False).tolist()
    hyps = get_texts(best_paths)
    assert len(scores) == len(hyps)

    results = []
    for token_int, score in zip(hyps, scores):
        # Change integer-ids to tokens
        token = self.converter.ids2tokens(token_int)
        if self.tokenizer is not None:
            text = self.tokenizer.tokens2text(token)
        else:
            text = None
        results.append((text, token, token_int, score))

    assert check_return_type(results)
    return results
def decode(
    dataloader: torch.utils.data.DataLoader,
    model: AcousticModel,
    device: Union[str, torch.device],
    HCLG: Fsa,
    *,
    search_beam: float = 20.0,
    output_beam: float = 7.0,
    min_active_states: int = 30,
    max_active_states: int = 10000,
):
    """Decode every batch in ``dataloader`` with ``HCLG`` and collect results.

    For each utterance, the best path of the pruned lattice is paired with
    the utterance's ``supervisions["is_voice"]`` reference row. The two are
    only checked to have the same number of frames; no scoring happens here.

    Args:
        dataloader: Yields batches with ``"inputs"`` of shape [N, T, C] and a
            ``"supervisions"`` dict. That dict holds ``sequence_idx``,
            ``start_frame``, ``duration``, ``is_voice`` and ``cut``.
        model: Network mapping [N, C, T] features to [N, C, T] outputs. It
            must expose ``subsampling_factor``.
        device: Device to run the model on. It must match ``HCLG.device``.
        HCLG: The decoding graph.
        search_beam, output_beam, min_active_states, max_active_states:
            Pruning parameters forwarded to ``k2.intersect_dense_pruned``.
            The defaults match the previously hard-coded values.

    Returns:
        A list of ``(cut, ref, hyp)`` tuples, one per utterance. ``ref`` is
        the ``is_voice`` row and ``hyp`` the best-path labels, on CPU.
    """
    tot_num_cuts = len(dataloader.dataset.cuts)
    num_cuts = 0
    results = []  # a list of (cut, ref_labels, hyp_labels) tuples
    for batch_idx, batch in enumerate(dataloader):
        feature = batch["inputs"]  # (N, T, C)
        supervisions = batch["supervisions"]
        feature = feature.to(device)

        # Since we are decoding with a k2 graph here, we need to create appropriate
        # supervisions. The segments need to be ordered in decreasing order of
        # length (although in our case all segments are of same length)
        supervision_segments = torch.stack(
            (
                supervisions["sequence_idx"],
                torch.floor_divide(supervisions["start_frame"],
                                   model.subsampling_factor),
                torch.floor_divide(supervisions["duration"],
                                   model.subsampling_factor),
            ),
            1,
        ).to(torch.int32)
        indices = torch.argsort(supervision_segments[:, 2], descending=True)
        supervision_segments = supervision_segments[indices]

        # at entry, feature is [N, T, C]
        feature = feature.permute(0, 2, 1)  # now feature is [N, C, T]
        with torch.no_grad():
            nnet_output = model(feature)
        # nnet_output is [N, C, T]
        nnet_output = nnet_output.permute(0, 2, 1)  # now nnet_output is [N, T, C]

        dense_fsa_vec = k2.DenseFsaVec(nnet_output, supervision_segments)

        assert HCLG.device == nnet_output.device, (
            f"Check failed: HCLG.device ({HCLG.device}) == "
            f"nnet_output.device ({nnet_output.device})")

        lattices = k2.intersect_dense_pruned(
            a_fsas=HCLG,
            b_fsas=dense_fsa_vec,
            search_beam=search_beam,
            output_beam=output_beam,
            min_active_states=min_active_states,
            max_active_states=max_active_states,
        )
        best_paths = k2.shortest_path(lattices, use_double_scores=True)
        assert best_paths.shape[0] == supervisions["is_voice"].shape[0]

        # best_paths is an FsaVec, and each of its FSAs is a linear FSA
        references = supervisions["is_voice"][indices]
        for i in range(references.shape[0]):
            ref = references[i, :]
            # Drop the last arc (the one entering the final state); column 2
            # of the arc tensor holds the labels.
            hyp = k2.arc_sort(
                best_paths[i]).arcs_as_tensor()[:-1, 2].detach().cpu()
            assert ref.shape[0] == hyp.shape[0], (
                "reference and hypothesis have unequal number of frames, "
                f"{ref.shape[0]} vs. {hyp.shape[0]}")
            results.append((supervisions["cut"][indices[i]], ref, hyp))

        if batch_idx % 10 == 0:
            logging.info(
                "batch %d, cuts processed until now is %d/%d (%.6f%%)",
                batch_idx,
                num_cuts,
                tot_num_cuts,
                float(num_cuts) / tot_num_cuts * 100,
            )

        num_cuts += supervisions["is_voice"].shape[0]

    return results