def test_single_fsa(self):
    s = '''
        0 4 1 1
        0 1 1 1
        1 2 1 2
        1 3 1 3
        2 7 1 4
        3 7 1 5
        4 6 1 2
        4 8 1 3
        5 9 -1 4
        6 9 -1 3
        7 9 -1 5
        8 9 -1 6
        9
    '''
    for device in self.devices:
        fsa = k2.Fsa.from_str(s).to(device)
        fsa = k2.create_fsa_vec([fsa])
        fsa.requires_grad_(True)
        best_path = k2.shortest_path(fsa, use_double_scores=False)

        # we recompute the total_scores for backprop
        total_scores = best_path.scores.sum()
        assert total_scores == 14

        expected = torch.zeros(12)
        expected[torch.tensor([1, 3, 5, 10])] = 1
        total_scores.backward()
        assert torch.allclose(fsa.scores.grad, expected.to(device))
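# A minimal, self-contained sketch (CPU; uses only k2 calls that already appear
# in this file) of what k2.shortest_path returns: a linear FsaVec whose labels
# spell out the best path. With k2.linear_fsa all arc scores are 0, so the
# single path is trivially the best one.
import k2

fsa = k2.linear_fsa([1, 2, 3])
fsa_vec = k2.create_fsa_vec([fsa])
best = k2.shortest_path(fsa_vec, use_double_scores=False)
print(best[0].labels)  # tensor([ 1,  2,  3, -1], dtype=torch.int32)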
def get_hierarchical_targets(ys: List[List[int]],
                             lexicon: k2.Fsa) -> List[Tensor]:
    """Get hierarchical transcripts (i.e., phone level transcripts) from
    transcripts (i.e., word level transcripts).

    Args:
        ys: Word level transcripts.
        lexicon: Its labels are words, while its aux_labels are phones.

    Returns:
        List[Tensor]: Phone level transcripts.
    """
    if lexicon is None:
        return ys
    else:
        L_inv = lexicon

    n_batch = len(ys)
    indices = torch.tensor(range(n_batch))

    transcripts = k2.create_fsa_vec([k2.linear_fsa(x) for x in ys])
    transcripts_lexicon = k2.intersect(transcripts, L_inv)
    transcripts_lexicon = k2.arc_sort(k2.connect(transcripts_lexicon))
    transcripts_lexicon = k2.remove_epsilon(transcripts_lexicon)
    transcripts_lexicon = k2.shortest_path(transcripts_lexicon,
                                           use_double_scores=True)

    ys = get_texts(transcripts_lexicon, indices)
    ys = [torch.tensor(y) for y in ys]

    return ys
def decode(dataloader: torch.utils.data.DataLoader, model: AcousticModel,
           device: Union[str, torch.device], HLG: Fsa, symbols: SymbolTable):
    tot_num_cuts = len(dataloader.dataset.cuts)
    num_cuts = 0
    results = []  # a list of pair (ref_words, hyp_words)
    for batch_idx, batch in enumerate(dataloader):
        feature = batch['inputs']
        supervisions = batch['supervisions']
        supervision_segments = torch.stack(
            (supervisions['sequence_idx'],
             torch.floor_divide(supervisions['start_frame'],
                                model.subsampling_factor),
             torch.floor_divide(supervisions['num_frames'],
                                model.subsampling_factor)), 1).to(torch.int32)
        indices = torch.argsort(supervision_segments[:, 2], descending=True)
        supervision_segments = supervision_segments[indices]
        texts = supervisions['text']
        assert feature.ndim == 3

        feature = feature.to(device)
        # at entry, feature is [N, T, C]
        feature = feature.permute(0, 2, 1)  # now feature is [N, C, T]
        with torch.no_grad():
            nnet_output = model(feature)
        # nnet_output is [N, C, T]
        nnet_output = nnet_output.permute(0, 2, 1)
        # now nnet_output is [N, T, C]

        # blank_bias = -3.0
        # nnet_output[:, :, 0] += blank_bias

        dense_fsa_vec = k2.DenseFsaVec(nnet_output, supervision_segments)
        # assert HLG.is_cuda()
        assert HLG.device == nnet_output.device, \
            f"Check failed: HLG.device ({HLG.device}) == nnet_output.device ({nnet_output.device})"

        # TODO(haowen): with a small `beam`, we may get empty `target_graph`,
        # thus `tot_scores` will be `inf`. Definitely we need to handle this later.
        lattices = k2.intersect_dense_pruned(HLG, dense_fsa_vec, 20.0, 7.0, 30,
                                             10000)
        # lattices = k2.intersect_dense(HLG, dense_fsa_vec, 10.0)

        best_paths = k2.shortest_path(lattices, use_double_scores=True)
        assert best_paths.shape[0] == len(texts)

        hyps = get_texts(best_paths, indices)
        assert len(hyps) == len(texts)

        for i in range(len(texts)):
            hyp_words = [symbols.get(x) for x in hyps[i]]
            ref_words = texts[i].split(' ')
            results.append((ref_words, hyp_words))

        if batch_idx % 10 == 0:
            logging.info(
                'batch {}, cuts processed until now is {}/{} ({:.6f}%)'.format(
                    batch_idx, num_cuts, tot_num_cuts,
                    float(num_cuts) / tot_num_cuts * 100))

        num_cuts += len(texts)

    return results
def decode(dataloader: torch.utils.data.DataLoader,
           model: None,
           device: Union[str, torch.device],
           ctc_topo: None,
           numericalizer=None,
           num_paths=-1,
           output_beam_size: float = 8):
    tot_num_cuts = len(dataloader.dataset.cuts)
    num_cuts = 0
    results = []

    for batch_idx, batch in enumerate(dataloader):
        assert isinstance(batch, dict), type(batch)
        feature = batch['inputs']
        supervisions = batch['supervisions']
        supervision_segments = torch.stack(
            (supervisions['sequence_idx'],
             (((supervisions['start_frame'] - 1) // 2 - 1) // 2),
             (((supervisions['num_frames'] - 1) // 2 - 1) // 2)),
            1).to(torch.int32)
        supervision_segments = torch.clamp(supervision_segments, min=0)
        indices = torch.argsort(supervision_segments[:, 2], descending=True)
        supervision_segments = supervision_segments[indices]
        texts = supervisions['text']
        assert feature.ndim == 3

        feature = feature.to(device)
        # at entry, feature is [N, T, C]
        feature = feature.permute(0, 2, 1)  # now feature is [N, C, T]
        nnet_output, encoder_memory, memory_mask = model(feature, supervisions)
        nnet_output = nnet_output.permute(0, 2, 1)

        # TODO(Liyong Guo): Tune this bias
        # blank_bias = 0.0
        # nnet_output[:, :, 0] += blank_bias

        with torch.no_grad():
            dense_fsa_vec = k2.DenseFsaVec(nnet_output, supervision_segments)
            lattices = k2.intersect_dense_pruned(ctc_topo, dense_fsa_vec, 20.0,
                                                 output_beam_size, 30, 10000)
            best_paths = k2.shortest_path(lattices, use_double_scores=True)
            hyps = get_texts(best_paths, indices)

        assert len(hyps) == len(texts)

        for i in range(len(texts)):
            pieces = [
                numericalizer.tokens_list[token_id] for token_id in hyps[i]
            ]
            hyp_words = numericalizer.tokenizer.DecodePieces(pieces).split(' ')
            ref_words = texts[i].split(' ')
            results.append((ref_words, hyp_words))

        if batch_idx % 10 == 0:
            logging.info(
                'batch {}, cuts processed until now is {}/{} ({:.6f}%)'.format(
                    batch_idx, num_cuts, tot_num_cuts,
                    float(num_cuts) / tot_num_cuts * 100))

        num_cuts += len(texts)

    return results
def decode(dataloader: torch.utils.data.DataLoader, model: AcousticModel,
           device: Union[str, torch.device], LG: Fsa, symbols: SymbolTable):
    results = []  # a list of pair (ref_words, hyp_words)
    for batch_idx, batch in enumerate(dataloader):
        feature = batch['features']
        supervisions = batch['supervisions']
        supervision_segments = torch.stack(
            (supervisions['sequence_idx'],
             torch.floor_divide(supervisions['start_frame'],
                                model.subsampling_factor),
             torch.floor_divide(supervisions['num_frames'],
                                model.subsampling_factor)), 1).to(torch.int32)
        texts = supervisions['text']
        assert feature.ndim == 3

        feature = feature.to(device)
        # at entry, feature is [N, T, C]
        feature = feature.permute(0, 2, 1)  # now feature is [N, C, T]
        with torch.no_grad():
            nnet_output = model(feature)
        # nnet_output is [N, C, T]
        nnet_output = nnet_output.permute(0, 2, 1)
        # now nnet_output is [N, T, C]

        dense_fsa_vec = k2.DenseFsaVec(nnet_output, supervision_segments)
        assert LG.is_cuda()
        assert LG.device == nnet_output.device, \
            f"Check failed: LG.device ({LG.device}) == nnet_output.device ({nnet_output.device})"

        # TODO(haowen): with a small `beam`, we may get empty `target_graph`,
        # thus `tot_scores` will be `inf`. Definitely we need to handle this later.
        lattices = k2.intersect_dense_pruned(LG, dense_fsa_vec, 2000.0, 20.0,
                                             30, 300)

        best_paths = k2.shortest_path(lattices, use_float_scores=True)
        best_paths = best_paths.to('cpu')
        assert best_paths.shape[0] == len(texts)

        for i in range(len(texts)):
            hyp_words = [
                symbols.get(x) for x in best_paths[i].aux_labels if x > 0
            ]
            results.append((texts[i].split(' '), hyp_words))

        if batch_idx % 10 == 0:
            logging.info('Processed batch {}/{} ({:.6f}%)'.format(
                batch_idx, len(dataloader),
                float(batch_idx) / len(dataloader) * 100))

    return results
def align(self, cuts: Union[AnyCut, CutSet]) -> torch.Tensor:
    """
    Perform forced alignment and return a tensor that represents a batch of
    frame-level alignments:

    >>> alignments = torch.tensor([
    ...     [0, 0, 0, 1, 57, 57, 35, 35, 35, ...],
    ...     [...],
    ...     ...
    ... ])

    :return: an int32 tensor with shape ``(batch_size, num_frames)``.
    """
    # Extract feats
    # (batch, seq_len, num_feats)
    if isinstance(cuts, (Cut, MixedCut)):
        cuts = CutSet.from_cuts([cuts])
    assert cuts[0].sampling_rate == self.sampling_rate, \
        f'{cuts[0].sampling_rate} != {self.sampling_rate}'
    cuts = cuts.map_supervisions(self.normalize_text)

    otf = OnTheFlyFeatures(self.extractor)
    feats, _ = otf(cuts)
    feats = feats.permute(0, 2, 1)
    texts = [' '.join(s.text for s in cut.supervisions) for cut in cuts]

    # Compute AM posteriors
    # (batch, seq_len ~/ 4, num_phones)
    posteriors, _, _ = self.model(feats)
    # Note: we are using "dummy" supervisions so that the aligner also considers
    # the padding area. We can adjust that behaviour if needed by passing actual
    # supervision segments, but then we will have a ragged tensor (will need to
    # pad the alignments themselves).
    sups = self.dummy_supervisions(feats)
    posteriors_fsa = k2.DenseFsaVec(posteriors.permute(0, 2, 1), sups)

    # Intersection with ground truth transcript graphs
    num, den = self.compiler.compile(texts, self.P)
    alignment = k2.intersect_dense(num, posteriors_fsa, output_beam=10.0)
    best_path = k2.shortest_path(alignment, use_double_scores=True)

    # Retrieve sequences of phone IDs per frame
    # (batch, seq_len ~/ 4) -- dtype int32 (num phone labels)
    frame_labels = torch.stack(
        [best_path[i].labels[:-1] for i in range(best_path.shape[0])])
    return frame_labels
def decode(
        self,
        cuts: Union[AnyCut, CutSet]) -> List[Tuple[List[str], List[str]]]:
    """
    Perform decoding with an n-gram language model (HLG graph).
    Doesn't support rescoring at this time.
    """
    if isinstance(cuts, (Cut, MixedCut)):
        cuts = CutSet.from_cuts([cuts])
    word_results = []
    # Hacky way to get batch quickly... we may need to improve on this.
    batch = K2SpeechRecognitionDataset(cuts,
                                       input_strategy=OnTheFlyFeatures(
                                           self.extractor),
                                       check_inputs=False)[list(cuts.ids)]
    features = batch['inputs'].permute(0, 2, 1).to(
        self.device)  # (B, T, F) -> (B, F, T)
    supervision_segments, texts = encode_supervisions(batch['supervisions'])

    # Forward pass through the acoustic model
    posteriors, _, _ = self.model(features)
    posteriors = posteriors.permute(0, 2, 1)  # (B, F, T) -> (B, T, F)

    # Wrapping into k2 "dense FSA" (representing PPG as a dense graph)
    dense_fsa_vec = k2.DenseFsaVec(posteriors, supervision_segments)

    # The actual decoding starts here:
    # First, we intersect the HLG and the PPG
    # with default pruning/beam search params from snowfall
    # The result is a batch of graphs (lattices)
    lattices = k2.intersect_dense_pruned(self.HLG, dense_fsa_vec, 20.0, 8, 30,
                                         10000)
    # ... then we find the shortest paths in the lattices ...
    best_paths = k2.shortest_path(lattices, use_double_scores=True)
    # ... and convert them to words with a convenience wrapper from snowfall
    hyps = get_texts(best_paths, torch.arange(len(texts)))

    # Here we read out the words from the best path graphs
    for i in range(len(texts)):
        hyp_words = [self.lexicon.words.get(x) for x in hyps[i]]
        ref_words = texts[i].split(' ')
        word_results.append((ref_words, hyp_words))
    return word_results
def get_hierarchical_targets(ys: List[List[int]],
                             lexicon: k2.Fsa) -> List[Tensor]:
    """Get hierarchical transcripts (i.e., phone level transcripts) from
    transcripts (i.e., word level transcripts).

    Args:
        ys: Word level transcripts.
        lexicon: Its labels are words, while its aux_labels are phones.

    Returns:
        List[Tensor]: Phone level transcripts.
    """
    if lexicon is None:
        return ys
    else:
        L_inv = lexicon

    n_batch = len(ys)
    indices = torch.tensor(range(n_batch))
    device = L_inv.device

    transcripts = k2.create_fsa_vec(
        [k2.linear_fsa(x, device=device) for x in ys])
    transcripts_with_self_loops = k2.add_epsilon_self_loops(transcripts)

    transcripts_lexicon = k2.intersect(L_inv,
                                       transcripts_with_self_loops,
                                       treat_epsilons_specially=False)
    # Don't call invert_() above because we want to return phone IDs,
    # which is the `aux_labels` of transcripts_lexicon
    transcripts_lexicon = k2.remove_epsilon(transcripts_lexicon)
    transcripts_lexicon = k2.top_sort(transcripts_lexicon)

    transcripts_lexicon = k2.shortest_path(transcripts_lexicon,
                                           use_double_scores=True)

    ys = get_texts(transcripts_lexicon, indices)
    ys = [torch.tensor(y) for y in ys]

    return ys
def intersect(self, lats: Fsa) -> 'Nbest':
    '''Intersect this Nbest object with a lattice and get 1-best path
    from the resulting FsaVec.

    Caution:
      We assume FSAs in `self.fsa` don't have epsilon self-loops.
      We also assume `self.fsa.labels` and `lats.labels` are token IDs.

    Args:
      lats:
        An FsaVec. It can be the return value of
        :func:`whole_lattice_rescoring`.
    Returns:
      Return a new Nbest. This new Nbest shares the same shape with `self`,
      while its `fsa` is the 1-best path from intersecting `self.fsa` and
      `lats`.
    '''
    assert self.fsa.device == lats.device, \
        f'{self.fsa.device} vs {lats.device}'
    assert len(lats.shape) == 3, f'{lats.shape}'
    assert lats.arcs.dim0() == self.shape.dim0(), \
        f'{lats.arcs.dim0()} vs {self.shape.dim0()}'

    lats = k2.arc_sort(lats)  # no-op if lats is already arc sorted

    fsas_with_epsilon_loops = k2.add_epsilon_self_loops(self.fsa)

    path_to_seq_map = self.shape.row_ids(1)

    ans_lats = k2.intersect_device(a_fsas=lats,
                                   b_fsas=fsas_with_epsilon_loops,
                                   b_to_a_map=path_to_seq_map,
                                   sorted_match_a=True)

    one_best = k2.shortest_path(ans_lats, use_double_scores=True)

    one_best = k2.remove_epsilon(one_best)

    return Nbest(fsa=one_best, shape=self.shape)
def __call__(
    self, batch: Dict[str, Union[torch.Tensor, np.ndarray]]
) -> List[Tuple[Optional[str], List[str], List[int], float]]:
    """Inference

    Args:
        batch: Input speech data and corresponding lengths

    Returns:
        text, token, token_int, hyp
    """
    assert check_argument_types()

    if isinstance(batch["speech"], np.ndarray):
        batch["speech"] = torch.tensor(batch["speech"])
    if isinstance(batch["speech_lengths"], np.ndarray):
        batch["speech_lengths"] = torch.tensor(batch["speech_lengths"])

    # a. To device
    batch = to_device(batch, device=self.device)

    # b. Forward Encoder
    # enc: [N, T, C]
    enc, encoder_out_lens = self.asr_model.encode(**batch)

    # logp_encoder_output: [N, T, C]
    logp_encoder_output = torch.nn.functional.log_softmax(
        self.asr_model.ctc.ctc_lo(enc), dim=2)

    batch_size = encoder_out_lens.size(0)
    sequence_idx = torch.arange(0, batch_size).unsqueeze(0).t().to(
        torch.int32)
    start_frame = torch.zeros([batch_size],
                              dtype=torch.int32).unsqueeze(0).t()
    num_frames = encoder_out_lens.cpu().unsqueeze(0).t().to(torch.int32)
    supervision_segments = torch.cat(
        [sequence_idx, start_frame, num_frames], dim=1)
    supervision_segments = supervision_segments.to(torch.int32)

    dense_fsa_vec = k2.DenseFsaVec(logp_encoder_output, supervision_segments)

    lattices = k2.intersect_dense_pruned(self.decode_graph, dense_fsa_vec,
                                         20.0, self.output_beam_size, 30,
                                         10000)

    best_paths = k2.shortest_path(lattices, use_double_scores=True)

    scores = best_paths.get_tot_scores(use_double_scores=True,
                                       log_semiring=False).tolist()

    hyps = get_texts(best_paths)

    assert len(scores) == len(hyps)

    results = []

    for token_int, score in zip(hyps, scores):
        # Change integer-ids to tokens
        token = self.converter.ids2tokens(token_int)

        if self.tokenizer is not None:
            text = self.tokenizer.tokens2text(token)
        else:
            text = None
        results.append((text, token, token_int, score))

    assert check_return_type(results)
    return results
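# A small illustrative sketch (hypothetical lengths, not from any model run)
# of the supervision_segments layout built above: one row per sequence, with
# columns (sequence_idx, start_frame, num_frames), all int32, as expected by
# k2.DenseFsaVec.
import torch

encoder_out_lens = torch.tensor([95, 60])  # assumed encoder output lengths
batch_size = encoder_out_lens.size(0)
sequence_idx = torch.arange(0, batch_size).unsqueeze(0).t().to(torch.int32)
start_frame = torch.zeros([batch_size], dtype=torch.int32).unsqueeze(0).t()
num_frames = encoder_out_lens.unsqueeze(0).t().to(torch.int32)
print(torch.cat([sequence_idx, start_frame, num_frames], dim=1))
# tensor([[ 0,  0, 95],
#         [ 1,  0, 60]], dtype=torch.int32)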
def decode(
    dataloader: torch.utils.data.DataLoader,
    model: AcousticModel,
    device: Union[str, torch.device],
    HLG: Fsa,
    symbols: SymbolTable,
):
    num_batches = None
    try:
        num_batches = len(dataloader)
    except TypeError:
        pass
    num_cuts = 0
    results = []  # a list of pair (ref_words, hyp_words)
    for batch_idx, batch in enumerate(dataloader):
        feature = batch["inputs"]
        supervisions = batch["supervisions"]
        supervision_segments = torch.stack(
            (
                supervisions["sequence_idx"],
                torch.floor_divide(supervisions["start_frame"],
                                   model.subsampling_factor),
                torch.floor_divide(supervisions["num_frames"],
                                   model.subsampling_factor),
            ),
            1,
        ).to(torch.int32)
        indices = torch.argsort(supervision_segments[:, 2], descending=True)
        supervision_segments = supervision_segments[indices]
        texts = supervisions["text"]
        assert feature.ndim == 3

        feature = feature.to(device)
        # at entry, feature is [N, T, C]
        feature = feature.permute(0, 2, 1)  # now feature is [N, C, T]
        with torch.no_grad():
            nnet_output = model(feature)
        # nnet_output is [N, C, T]
        nnet_output = nnet_output.permute(0, 2, 1)
        # now nnet_output is [N, T, C]

        blank_bias = -3.0
        nnet_output[:, :, 0] += blank_bias

        dense_fsa_vec = k2.DenseFsaVec(nnet_output, supervision_segments)
        # assert HLG.is_cuda()
        assert (
            HLG.device == nnet_output.device
        ), f"Check failed: HLG.device ({HLG.device}) == nnet_output.device ({nnet_output.device})"

        # TODO(haowen): with a small `beam`, we may get empty `target_graph`,
        # thus `tot_scores` will be `inf`. Definitely we need to handle this later.
        lattices = k2.intersect_dense_pruned(HLG, dense_fsa_vec, 20.0, 7.0, 30,
                                             10000)
        # lattices = k2.intersect_dense(HLG, dense_fsa_vec, 10.0)

        best_paths = k2.shortest_path(lattices, use_double_scores=True)
        assert best_paths.shape[0] == len(texts)

        hyps = get_texts(best_paths, indices)
        assert len(hyps) == len(texts)

        for i in range(len(texts)):
            hyp_words = [symbols.get(x) for x in hyps[i]]
            ref_words = texts[i].split(" ")
            results.append((ref_words, hyp_words))

        if batch_idx % 10 == 0:
            batch_str = f"{batch_idx}" if num_batches is None else f"{batch_idx}/{num_batches}"
            logging.info(
                f"batch {batch_str}, number of cuts processed until now is {num_cuts}"
            )

        num_cuts += len(texts)

    return results
def rescore_with_whole_lattice(lats: k2.Fsa,
                               G_with_epsilon_loops: k2.Fsa) -> k2.Fsa:
    '''Use whole lattice to rescore.

    Args:
      lats:
        An FsaVec. It can be the output of `k2.intersect_dense_pruned`.
      G_with_epsilon_loops:
        An FsaVec representing the language model (LM). Note that it is an
        FsaVec, but it contains only one Fsa.
    '''
    assert len(lats.shape) == 3
    assert hasattr(lats, 'lm_scores')
    assert G_with_epsilon_loops.shape == (1, None, None)

    device = lats.device
    lats.scores = lats.scores - lats.lm_scores
    # Now, lats.scores contains only am_scores

    # inverted_lats has word IDs as labels.
    # Its aux_labels are phone IDs, which is a ragged tensor k2.RaggedInt
    inverted_lats = k2.invert(lats)
    num_seqs = lats.shape[0]
    inverted_lats_with_epsilon_loops = k2.add_epsilon_self_loops(inverted_lats)

    b_to_a_map = torch.zeros(num_seqs, device=device, dtype=torch.int32)
    try:
        rescoring_lats = k2.intersect_device(G_with_epsilon_loops,
                                             inverted_lats_with_epsilon_loops,
                                             b_to_a_map,
                                             sorted_match_a=True)
    except RuntimeError as e:
        print(f'Caught exception:\n{e}\n')
        print(f'Number of FSAs: {inverted_lats.shape[0]}')
        print('num_arcs before pruning: ',
              inverted_lats_with_epsilon_loops.arcs.num_elements())

        # NOTE(fangjun): The choice of the threshold 0.001 is arbitrary here
        # to avoid OOM. We may need to fine tune it.
        inverted_lats = k2.prune_on_arc_post(inverted_lats, 0.001, True)
        inverted_lats_with_epsilon_loops = k2.add_epsilon_self_loops(
            inverted_lats)
        print('num_arcs after pruning: ',
              inverted_lats_with_epsilon_loops.arcs.num_elements())

        rescoring_lats = k2.intersect_device(G_with_epsilon_loops,
                                             inverted_lats_with_epsilon_loops,
                                             b_to_a_map,
                                             sorted_match_a=True)

    rescoring_lats = k2.top_sort(k2.connect(
        rescoring_lats.to('cpu'))).to(device)
    inverted_rescoring_lats = k2.invert(rescoring_lats)
    # inverted_rescoring_lats has phone IDs as labels
    # and word IDs as aux_labels.

    inverted_rescoring_lats = k2.remove_epsilon_self_loops(
        inverted_rescoring_lats)
    best_paths = k2.shortest_path(inverted_rescoring_lats,
                                  use_double_scores=True)
    return best_paths
def rescore_with_whole_lattice(lats: k2.Fsa,
                               G_with_epsilon_loops: k2.Fsa,
                               lm_scale_list: List[float]
                               ) -> Dict[str, k2.Fsa]:
    '''Use whole lattice to rescore.

    Args:
      lats:
        An FsaVec. It can be the output of `k2.intersect_dense_pruned`.
      G_with_epsilon_loops:
        An FsaVec representing the language model (LM). Note that it is an
        FsaVec, but it contains only one Fsa.
      lm_scale_list:
        A list containing lm_scale values.
    Returns:
      A dict of FsaVec, whose key is an lm_scale value and whose value
      contains the best decoding path for each sequence in the lattice.
    '''
    assert len(lats.shape) == 3
    assert hasattr(lats, 'lm_scores')
    assert G_with_epsilon_loops.shape == (1, None, None)

    device = lats.device
    lats.scores = lats.scores - lats.lm_scores
    # We will use lm_scores from G, so remove lats.lm_scores here
    del lats.lm_scores
    assert hasattr(lats, 'lm_scores') is False

    # lats.scores = scores / lm_scale
    # Now, lats.scores contains only am_scores

    # inverted_lats has word IDs as labels.
    # Its aux_labels are phone IDs, which is a ragged tensor k2.RaggedInt
    inverted_lats = k2.invert(lats)
    num_seqs = lats.shape[0]

    b_to_a_map = torch.zeros(num_seqs, device=device, dtype=torch.int32)
    try:
        rescoring_lats = k2.intersect_device(G_with_epsilon_loops,
                                             inverted_lats,
                                             b_to_a_map,
                                             sorted_match_a=True)
    except RuntimeError as e:
        print(f'Caught exception:\n{e}\n')
        print(f'Number of FSAs: {inverted_lats.shape[0]}')
        print('num_arcs before pruning: ',
              inverted_lats.arcs.num_elements())

        # NOTE(fangjun): The choice of the threshold 0.001 is arbitrary here
        # to avoid OOM. We may need to fine tune it.
        inverted_lats = k2.prune_on_arc_post(inverted_lats, 0.001, True)
        print('num_arcs after pruning: ', inverted_lats.arcs.num_elements())

        rescoring_lats = k2.intersect_device(G_with_epsilon_loops,
                                             inverted_lats,
                                             b_to_a_map,
                                             sorted_match_a=True)

    rescoring_lats = k2.top_sort(
        k2.connect(rescoring_lats.to('cpu')).to(device))

    # inv_lats has phone IDs as labels
    # and word IDs as aux_labels.
    inv_lats = k2.invert(rescoring_lats)

    ans = dict()
    #
    # The following implements
    #   scores = (scores - lm_scores) / lm_scale + lm_scores
    #          = scores / lm_scale + lm_scores * (1 - 1/lm_scale)
    #
    saved_scores = inv_lats.scores.clone()
    for lm_scale in lm_scale_list:
        am_scores = saved_scores - inv_lats.lm_scores
        am_scores /= lm_scale
        inv_lats.scores = am_scores + inv_lats.lm_scores

        best_paths = k2.shortest_path(inv_lats, use_double_scores=True)
        key = f'lm_scale_{lm_scale}'
        ans[key] = best_paths
    return ans
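# A small numeric sketch (illustrative values only, not taken from any recipe)
# of the per-arc score combination used in the lm_scale loop above:
#   scores = (scores - lm_scores) / lm_scale + lm_scores
am_score = -10.0  # acoustic part of an arc score (hypothetical)
lm_score = -5.0   # LM part of the same arc (hypothetical)
for lm_scale in (0.8, 1.0, 1.2):
    combined = am_score / lm_scale + lm_score
    print(lm_scale, combined)  # 0.8 -> -17.5, 1.0 -> -15.0, 1.2 -> -13.33...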
def decode_one_batch(batch: Dict[str, Any],
                     model: AcousticModel,
                     HLG: k2.Fsa,
                     output_beam_size: float,
                     num_paths: int,
                     use_whole_lattice: bool,
                     G: Optional[k2.Fsa] = None
                     ) -> Dict[str, List[List[int]]]:
    '''Decode one batch and return the result in a dict. The dict has the
    following format:

        - key: It indicates the setting used for decoding. For example,
               if no rescoring is used, the key is the string `no_rescore`.
               If LM rescoring is used, the key is the string `lm_scale_xxx`,
               where `xxx` is the value of `lm_scale`. An example key is
               `lm_scale_0.7`
        - value: It contains the decoding result. `len(value)` equals the
                 batch size. `value[i]` is the decoding result for the i-th
                 utterance in the given batch.

    Args:
      batch:
        It is the return value from iterating
        `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
        for the format of the `batch`.
      model:
        The neural network model.
      HLG:
        The decoding graph.
      output_beam_size:
        Size of the beam for pruning.
      use_whole_lattice:
        If True, `G` must not be None and it will use the whole lattice for
        LM rescoring.
        If False and if `G` is not None, then `num_paths` must be positive
        and it will use an n-best list for LM rescoring.
      num_paths:
        It specifies the size of `n` in n-best list decoding.
      G:
        The LM. If it is None, no rescoring is used.
        Otherwise, LM rescoring is used.
        It supports two types of LM rescoring: n-best list rescoring
        and whole lattice rescoring.
        `use_whole_lattice` specifies which type to use.

    Returns:
      Return the decoding result. See the above description for the format
      of the returned dict.
    '''
    device = HLG.device
    feature = batch['inputs']
    assert feature.ndim == 3
    feature = feature.to(device)

    # at entry, feature is [N, T, C]
    feature = feature.permute(0, 2, 1)  # now feature is [N, C, T]

    supervisions = batch['supervisions']

    nnet_output, _, _ = model(feature, supervisions)
    # nnet_output is [N, C, T]
    nnet_output = nnet_output.permute(0, 2, 1)
    # now nnet_output is [N, T, C]

    supervision_segments = torch.stack(
        (supervisions['sequence_idx'],
         (((supervisions['start_frame'] - 1) // 2 - 1) // 2),
         (((supervisions['num_frames'] - 1) // 2 - 1) // 2)),
        1).to(torch.int32)
    supervision_segments = torch.clamp(supervision_segments, min=0)
    indices = torch.argsort(supervision_segments[:, 2], descending=True)
    supervision_segments = supervision_segments[indices]

    dense_fsa_vec = k2.DenseFsaVec(nnet_output, supervision_segments)

    lattices = k2.intersect_dense_pruned(HLG, dense_fsa_vec, 20.0,
                                         output_beam_size, 30, 10000)

    if G is None:
        if num_paths > 1:
            best_paths = nbest_decoding(lattices, num_paths)
            key = f'no_rescore-{num_paths}'
        else:
            key = 'no_rescore'
            best_paths = k2.shortest_path(lattices, use_double_scores=True)

        hyps = get_texts(best_paths, indices)
        return {key: hyps}

    lm_scale_list = [0.8, 0.9, 1.0, 1.1, 1.2, 1.3]
    lm_scale_list += [1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2.0]

    if use_whole_lattice:
        best_paths_dict = rescore_with_whole_lattice(lattices, G,
                                                     lm_scale_list)
    else:
        best_paths_dict = rescore_with_n_best_list(lattices, G, num_paths,
                                                   lm_scale_list)
    # best_paths_dict is a dict
    #  - key: lm_scale_xxx, where xxx is the value of lm_scale. An example
    #         key is lm_scale_1.2
    #  - value: it is the best path obtained using the corresponding lm scale
    #           from the dict key.

    ans = dict()
    for lm_scale_str, best_paths in best_paths_dict.items():
        hyps = get_texts(best_paths, indices)
        ans[lm_scale_str] = hyps
    return ans
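# Hypothetical caller-side sketch of decode_one_batch: `batch`, `model`, `HLG`
# and `symbols` are assumed to come from the surrounding recipe; the word-ID
# lookup mirrors the `symbols.get(x)` usage in the decode() functions above.
hyps_dict = decode_one_batch(batch=batch,
                             model=model,
                             HLG=HLG,
                             output_beam_size=8.0,
                             num_paths=-1,
                             use_whole_lattice=False,
                             G=None)
for key, hyps in hyps_dict.items():
    # key is 'no_rescore' here (or 'lm_scale_xxx' when G is given);
    # hyps[i] is the list of word IDs for the i-th utterance in the batch
    hyp_words = [[symbols.get(x) for x in hyp] for hyp in hyps]
    print(key, hyp_words[0])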
def levenshtein_alignment(
        refs: Fsa,
        hyps: Fsa,
        hyp_to_ref_map: torch.Tensor,
        sorted_match_ref: bool = False,
) -> Fsa:
    '''Get the levenshtein alignment of two FsaVecs

    This function supports both CPU and GPU. But it is very slow on CPU.

    Args:
      refs:
        An FsaVec (must have 3 axes, i.e., `len(refs.shape) == 3`). It is the
        output Fsa of the :func:`levenshtein_graph`.
      hyps:
        An FsaVec (must have 3 axes) on the same device as `refs`. It is the
        output Fsa of the :func:`levenshtein_graph`.
      hyp_to_ref_map:
        A 1-D torch.Tensor with dtype torch.int32 on the same device as
        `refs`. Map from FSA-id in `hyps` to the corresponding FSA-id in
        `refs` that we want to get levenshtein alignment with. E.g. might be
        an identity map, or all-to-zero, or something the user chooses.

        Requires
            - `hyp_to_ref_map.shape[0] == hyps.shape[0]`
            - `0 <= hyp_to_ref_map[i] < refs.shape[0]`
      sorted_match_ref:
        If true, the arcs of refs must be sorted by label (checked by
        calling code via properties), and we'll use a matching approach
        that requires this.

    Returns:
      Returns an FsaVec containing the alignment information and satisfying
      `ans.Dim0() == hyps.Dim0()`. Two attributes named `ref_labels` and
      `hyp_labels` will be added to the returned FsaVec. `ref_labels` contains
      the aligned sequences of refs and `hyp_labels` contains the aligned
      sequences of hyps. You can get the levenshtein distance by calling
      `get_tot_scores` on the returned FsaVec.

    Examples:
      >>> hyps = k2.levenshtein_graph([[1, 2, 3], [1, 3, 3, 2]])
      >>> refs = k2.levenshtein_graph([[1, 2, 4]])
      >>> alignment = k2.levenshtein_alignment(
              refs, hyps,
              hyp_to_ref_map=torch.tensor([0, 0], dtype=torch.int32),
              sorted_match_ref=True)
      >>> alignment.labels
      tensor([ 1,  2,  0, -1,  1,  0,  0,  0, -1], dtype=torch.int32)
      >>> alignment.ref_labels
      tensor([ 1,  2,  4, -1,  1,  2,  4,  0, -1], dtype=torch.int32)
      >>> alignment.hyp_labels
      tensor([ 1,  2,  3, -1,  1,  3,  3,  2, -1], dtype=torch.int32)
      >>> -alignment.get_tot_scores(
              use_double_scores=False, log_semiring=False)
      tensor([1., 3.])
    '''
    assert hasattr(refs, "aux_labels")
    assert hasattr(hyps, "aux_labels")

    hyps.rename_tensor_attribute_("aux_labels", "hyp_labels")

    lattice = k2.intersect_device(
        refs, hyps, b_to_a_map=hyp_to_ref_map,
        sorted_match_a=sorted_match_ref)
    lattice = k2.remove_epsilon_self_loops(lattice)

    alignment = k2.shortest_path(lattice, use_double_scores=True).invert_()
    alignment.rename_tensor_attribute_("labels", "ref_labels")
    alignment.rename_tensor_attribute_("aux_labels", "labels")

    alignment.scores -= getattr(
        alignment, "__ins_del_score_offset_internal_attr_")

    return alignment
def decode(
    self,
    log_probs: torch.Tensor,
    log_probs_length: torch.Tensor,
    return_lattices: bool = False,
    return_ilabels: bool = False,
    output_aligned: bool = True,
) -> Union['k2.Fsa', Tuple[List[torch.Tensor], List[torch.Tensor]]]:
    if self.decoding_graph is None:
        self.decoding_graph = self.base_graph

    if self.blank != 0:
        # rearrange log_probs to put blank at the first place
        # and shift targets to emulate blank = 0
        log_probs, _ = make_blank_first(self.blank, log_probs, None)
    supervisions, order = create_supervision(log_probs_length)
    if self.decoding_graph.shape[0] > 1:
        self.decoding_graph = k2.index_fsa(self.decoding_graph,
                                           order).to(device=log_probs.device)

    if log_probs.device != self.device:
        self.to(log_probs.device)
    dense_fsa_vec = (
        prep_padded_densefsavec(log_probs, supervisions)
        if self.pad_fsavec
        else k2.DenseFsaVec(log_probs, supervisions)
    )

    if self.intersect_pruned:
        lats = k2.intersect_dense_pruned(
            a_fsas=self.decoding_graph,
            b_fsas=dense_fsa_vec,
            search_beam=self.intersect_conf.search_beam,
            output_beam=self.intersect_conf.output_beam,
            min_active_states=self.intersect_conf.min_active_states,
            max_active_states=self.intersect_conf.max_active_states,
        )
    else:
        indices = torch.zeros(dense_fsa_vec.dim0(), dtype=torch.int32, device=self.device)
        dec_graphs = (
            k2.index_fsa(self.decoding_graph, indices)
            if self.decoding_graph.shape[0] == 1
            else self.decoding_graph
        )
        lats = k2.intersect_dense(dec_graphs, dense_fsa_vec, self.intersect_conf.output_beam)
    if self.pad_fsavec:
        shift_labels_inpl([lats], -1)
    self.decoding_graph = None

    if return_lattices:
        lats = k2.index_fsa(lats, invert_permutation(order).to(device=log_probs.device))
        if self.blank != 0:
            # change only ilabels
            # suppose self.blank == self.num_classes - 1
            lats.labels = torch.where(lats.labels == 0, self.blank, lats.labels - 1)
        return lats
    else:
        shortest_path_fsas = k2.index_fsa(
            k2.shortest_path(lats, True),
            invert_permutation(order).to(device=log_probs.device),
        )
        shortest_paths = []
        probs = []
        # direct iterating does not work as expected
        for i in range(shortest_path_fsas.shape[0]):
            shortest_path_fsa = shortest_path_fsas[i]
            labels = (
                shortest_path_fsa.labels[:-1].to(dtype=torch.long)
                if return_ilabels
                else shortest_path_fsa.aux_labels[:-1].to(dtype=torch.long)
            )
            if self.blank != 0:
                # suppose self.blank == self.num_classes - 1
                labels = torch.where(labels == 0, self.blank, labels - 1)
            if not return_ilabels and not output_aligned:
                labels = labels[labels != self.blank]
            shortest_paths.append(labels[::2] if self.pad_fsavec else labels)
            probs.append(get_arc_weights(shortest_path_fsa)[:-1].to(device=log_probs.device).exp())
        return shortest_paths, probs
def test_fsa_vec(self):
    # best path:
    #   states: 0 -> 1 -> 3 -> 7 -> 9
    #   arcs:     1 -> 3 -> 5 -> 10
    s1 = '''
        0 4 1 1
        0 1 1 1
        1 2 1 2
        1 3 1 3
        2 7 1 4
        3 7 1 5
        4 6 1 2
        4 8 1 3
        5 9 -1 4
        6 9 -1 3
        7 9 -1 5
        8 9 -1 6
        9
    '''

    # best path:
    #   states: 0 -> 2 -> 3 -> 4 -> 5
    #   arcs:     1 -> 4 -> 5 -> 7
    s2 = '''
        0 1 1 1
        0 2 2 6
        1 2 3 3
        1 3 4 2
        2 3 5 4
        3 4 6 3
        3 5 -1 2
        4 5 -1 0
        5
    '''

    # best path:
    #   states: 0 -> 2 -> 3
    #   arcs:     1 -> 3
    s3 = '''
        0 1 1 10
        0 2 2 100
        1 3 -1 3.5
        2 3 -1 5.5
        3
    '''

    for device in self.devices:
        fsa1 = k2.Fsa.from_str(s1).to(device)
        fsa2 = k2.Fsa.from_str(s2).to(device)
        fsa3 = k2.Fsa.from_str(s3).to(device)

        fsa1.requires_grad_(True)
        fsa2.requires_grad_(True)
        fsa3.requires_grad_(True)

        fsa_vec = k2.create_fsa_vec([fsa1, fsa2, fsa3])
        assert fsa_vec.shape == (3, None, None)
        best_path = k2.shortest_path(fsa_vec, use_double_scores=False)

        # we recompute the total_scores for backprop
        total_scores = best_path.scores.sum()
        total_scores.backward()

        fsa1_best_arc_indexes = torch.tensor([1, 3, 5, 10], device=device)
        assert torch.all(
            torch.eq(fsa1.scores.grad[fsa1_best_arc_indexes],
                     torch.ones(4, device=device)))
        assert fsa1.scores.grad.sum() == 4

        fsa2_best_arc_indexes = torch.tensor([1, 4, 5, 7], device=device)
        assert torch.all(
            torch.eq(fsa2.scores.grad[fsa2_best_arc_indexes],
                     torch.ones(4, device=device)))
        assert fsa2.scores.grad.sum() == 4

        fsa3_best_arc_indexes = torch.tensor([1, 3], device=device)
        assert torch.all(
            torch.eq(fsa3.scores.grad[fsa3_best_arc_indexes],
                     torch.ones(2, device=device)))
        assert fsa3.scores.grad.sum() == 2
def __call__(
    self, batch: Dict[str, Union[torch.Tensor, np.ndarray]]
) -> List[Tuple[Optional[str], List[str], List[int], float]]:
    """Inference

    Args:
        batch: Input speech data and corresponding lengths

    Returns:
        text, token, token_int, hyp
    """
    assert check_argument_types()

    if isinstance(batch["speech"], np.ndarray):
        batch["speech"] = torch.tensor(batch["speech"])
    if isinstance(batch["speech_lengths"], np.ndarray):
        batch["speech_lengths"] = torch.tensor(batch["speech_lengths"])

    # a. To device
    batch = to_device(batch, device=self.device)

    # b. Forward Encoder
    # enc: [N, T, C]
    enc, encoder_out_lens = self.asr_model.encode(**batch)

    # logp_encoder_output: [N, T, C]
    logp_encoder_output = torch.nn.functional.log_softmax(
        self.asr_model.ctc.ctc_lo(enc), dim=2
    )

    # It may be useful to tune blank_bias.
    # The valid range of blank_bias is [-inf, 0]
    logp_encoder_output[:, :, 0] += self.blank_bias

    batch_size = encoder_out_lens.size(0)
    sequence_idx = torch.arange(0, batch_size).unsqueeze(0).t().to(torch.int32)
    start_frame = torch.zeros([batch_size], dtype=torch.int32).unsqueeze(0).t()
    num_frames = encoder_out_lens.cpu().unsqueeze(0).t().to(torch.int32)
    supervision_segments = torch.cat([sequence_idx, start_frame, num_frames], dim=1)
    supervision_segments = supervision_segments.to(torch.int32)

    # An introduction to DenseFsaVec:
    # https://k2-fsa.github.io/k2/core_concepts/index.html#dense-fsa-vector
    # It can be viewed as an FSA-typed version of logp_encoder_output,
    # whose arc weights are initialized from logp_encoder_output.
    # The goal of converting the tensor type to an FSA type is to use
    # FSA-related functions in k2, e.g. k2.intersect_dense_pruned below.
    dense_fsa_vec = k2.DenseFsaVec(logp_encoder_output, supervision_segments)

    # The term "intersect" is similar to "compose" in k2.
    # The difference is:
    # for "compose" functions, the composition matches the output label of
    # a.fsa against the input label of b.fsa,
    # while for "intersect" functions, the composition matches the input
    # label of a.fsa against the input label of b.fsa.
    # Actually, in compose functions, b.fsa is inverted and then
    # a.fsa and inv_b.fsa are intersected together.
    # For the difference between compose and intersect:
    # https://github.com/k2-fsa/k2/blob/master/k2/python/k2/fsa_algo.py#L308
    # For the definition of k2.intersect_dense_pruned:
    # https://github.com/k2-fsa/k2/blob/master/k2/python/k2/autograd.py#L648
    lattices = k2.intersect_dense_pruned(
        self.decode_graph,
        dense_fsa_vec,
        self.search_beam_size,
        self.output_beam_size,
        self.min_active_states,
        self.max_active_states,
    )

    # lattices.scores is the sum of decode_graph.scores (a.k.a. lm weight) and
    # dense_fsa_vec.scores (a.k.a. am weight) on related arcs.
    # For a CTC decoding graph, lattices.scores only stores the am weight
    # since the decode_graph only defines the CTC topology and
    # has no lm weight on its arcs.
    # While for 3-gram decoding, whose graph is converted from language
    # models, lattices.scores contains both am weights and lm weights.
    #
    # It may be useful to tune lattices.scores.
    # The valid range of lattice_weight is [0, inf).
    # The lattice_weight will affect the search of k2.random_paths
    lattices.scores *= self.lattice_weight

    results = []
    if self.use_nbest_rescoring:
        (
            am_scores,
            lm_scores,
            token_ids,
            new2old,
            path_to_seq_map,
            seq_to_path_splits,
        ) = nbest_am_lm_scores(
            lattices, self.num_paths, self.device, self.nbest_batch_size
        )

        ys_pad_lens = torch.tensor([len(hyp) for hyp in token_ids]).to(self.device)
        max_token_length = max(ys_pad_lens)
        ys_pad_list = []
        for hyp in token_ids:
            ys_pad_list.append(
                torch.cat(
                    [
                        torch.tensor(hyp, dtype=torch.long),
                        torch.tensor(
                            [self.asr_model.ignore_id]
                            * (max_token_length.item() - len(hyp)),
                            dtype=torch.long,
                        ),
                    ]
                )
            )

        ys_pad = (
            torch.stack(ys_pad_list).to(torch.long).to(self.device)
        )  # [batch, max_token_length]

        encoder_out = enc.index_select(0, path_to_seq_map.to(torch.long)).to(
            self.device
        )  # [batch, T, dim]
        encoder_out_lens = encoder_out_lens.index_select(
            0, path_to_seq_map.to(torch.long)
        ).to(
            self.device
        )  # [batch]

        decoder_scores = -self.asr_model.batchify_nll(
            encoder_out, encoder_out_lens, ys_pad, ys_pad_lens, self.nll_batch_size
        )

        # padded_value for nnlm is 0
        ys_pad[ys_pad == self.asr_model.ignore_id] = 0
        nnlm_nll, x_lengths = self.lm.batchify_nll(
            ys_pad, ys_pad_lens, self.nll_batch_size
        )
        nnlm_scores = -nnlm_nll.sum(dim=1)

        batch_tot_scores = (
            self.am_weight * am_scores
            + self.decoder_weight * decoder_scores
            + self.nnlm_weight * nnlm_scores
        )
        split_size = indices_to_split_size(
            seq_to_path_splits.tolist(), total_elements=batch_tot_scores.size(0)
        )
        batch_tot_scores = torch.split(
            batch_tot_scores,
            split_size,
        )

        hyps = []
        scores = []
        processed_seqs = 0
        for tot_scores in batch_tot_scores:
            if tot_scores.nelement() == 0:
                # the last element by torch.tensor_split may be empty
                # e.g.
                # torch.tensor_split(torch.tensor([1,2,3,4]), torch.tensor([2,4]))
                # (tensor([1, 2]), tensor([3, 4]), tensor([], dtype=torch.int64))
                break
            best_seq_idx = processed_seqs + torch.argmax(tot_scores)

            assert best_seq_idx < len(token_ids)
            best_token_seqs = token_ids[best_seq_idx]
            processed_seqs += tot_scores.nelement()
            hyps.append(best_token_seqs)
            scores.append(tot_scores.max().item())

        assert len(hyps) == len(split_size)
    else:
        best_paths = k2.shortest_path(lattices, use_double_scores=True)
        scores = best_paths.get_tot_scores(
            use_double_scores=True, log_semiring=False
        ).tolist()
        hyps = get_texts(best_paths)

    assert len(scores) == len(hyps)

    for token_int, score in zip(hyps, scores):
        # For the decoding methods nbest_rescoring and ctc_decoding,
        # hyps stores token indices, i.e., lattices.labels.

        # convert token_id to text with self.tokenizer
        token = self.converter.ids2tokens(token_int)

        assert self.tokenizer is not None
        text = self.tokenizer.tokens2text(token)
        results.append((text, token, token_int, score))

    assert check_return_type(results)
    return results
def decode(
    dataloader: torch.utils.data.DataLoader,
    model: AcousticModel,
    device: Union[str, torch.device],
    HCLG: Fsa,
):
    tot_num_cuts = len(dataloader.dataset.cuts)
    num_cuts = 0
    results = []  # a list of pair [ref_labels, hyp_labels]
    for batch_idx, batch in enumerate(dataloader):
        feature = batch["inputs"]  # (N, T, C)
        supervisions = batch["supervisions"]
        feature = feature.to(device)

        # Since we are decoding with a k2 graph here, we need to create appropriate
        # supervisions. The segments need to be ordered in decreasing order of
        # length (although in our case all segments are of same length)
        supervision_segments = torch.stack(
            (
                supervisions["sequence_idx"],
                torch.floor_divide(supervisions["start_frame"],
                                   model.subsampling_factor),
                torch.floor_divide(supervisions["duration"],
                                   model.subsampling_factor),
            ),
            1,
        ).to(torch.int32)
        indices = torch.argsort(supervision_segments[:, 2], descending=True)
        supervision_segments = supervision_segments[indices]

        # at entry, feature is [N, T, C]
        feature = feature.permute(0, 2, 1)  # now feature is [N, C, T]
        with torch.no_grad():
            nnet_output = model(feature)
        # nnet_output is [N, C, T]
        nnet_output = nnet_output.permute(0, 2, 1)
        # now nnet_output is [N, T, C]

        dense_fsa_vec = k2.DenseFsaVec(nnet_output, supervision_segments)
        # assert HCLG.is_cuda()
        assert (
            HCLG.device == nnet_output.device
        ), f"Check failed: HCLG.device ({HCLG.device}) == nnet_output.device ({nnet_output.device})"

        lattices = k2.intersect_dense_pruned(HCLG, dense_fsa_vec, 20.0, 7.0,
                                             30, 10000)

        best_paths = k2.shortest_path(lattices, use_double_scores=True)
        assert best_paths.shape[0] == supervisions["is_voice"].shape[0]

        # best_paths is an FsaVec, and each of its FSAs is a linear FSA
        references = supervisions["is_voice"][indices]
        for i in range(references.shape[0]):
            ref = references[i, :]
            hyp = k2.arc_sort(
                best_paths[i]).arcs_as_tensor()[:-1, 2].detach().cpu()
            assert (
                ref.shape[0] == hyp.shape[0]
            ), "reference and hypothesis have unequal number of frames, {} vs. {}".format(
                ref.shape[0], hyp.shape[0])
            results.append((supervisions["cut"][indices[i]], ref, hyp))

        if batch_idx % 10 == 0:
            logging.info(
                "batch {}, cuts processed until now is {}/{} ({:.6f}%)".format(
                    batch_idx,
                    num_cuts,
                    tot_num_cuts,
                    float(num_cuts) / tot_num_cuts * 100,
                ))

        num_cuts += supervisions["is_voice"].shape[0]

    return results
def __call__(
    self, speech: Union[torch.Tensor, np.ndarray]
) -> List[Tuple[Optional[str], List[str], List[int], float]]:
    """Inference

    Args:
        speech: Input speech data

    Returns:
        text, token, token_int, hyp
    """
    assert check_argument_types()

    # Input as audio signal
    if isinstance(speech, np.ndarray):
        speech = torch.tensor(speech)

    # data: (Nsamples,) -> (1, Nsamples)
    speech = speech.unsqueeze(0).to(getattr(torch, self.dtype))
    # lengths: (1,)
    lengths = speech.new_full([1], dtype=torch.long, fill_value=speech.size(1))
    batch = {"speech": speech, "speech_lengths": lengths}

    # a. To device
    batch = to_device(batch, device=self.device)

    # b. Forward Encoder
    # enc: [N, T, C]
    enc, _ = self.asr_model.encode(**batch)
    assert len(enc) == 1, len(enc)

    # logp_encoder_output: [N, T, C]
    logp_encoder_output = torch.nn.functional.log_softmax(
        self.asr_model.ctc.ctc_lo(enc), dim=2)

    # TODO(Liyong Guo): Support batch decoding.
    # The following statement only supports batch_size == 1.
    supervision_segments = torch.tensor([[0, 0, enc.shape[1]]],
                                        dtype=torch.int32)
    indices = torch.tensor([0])

    dense_fsa_vec = k2.DenseFsaVec(logp_encoder_output, supervision_segments)

    lattices = k2.intersect_dense_pruned(self.decode_graph, dense_fsa_vec,
                                         20.0, self.output_beam_size, 30,
                                         10000)

    best_paths = k2.shortest_path(lattices, use_double_scores=True)

    scores = best_paths.get_tot_scores(use_double_scores=True,
                                       log_semiring=False).tolist()

    hyps = get_texts(best_paths, indices)

    # TODO(Liyong Guo): Support batch decoding. Now batch_size == 1.
    assert len(scores) == 1
    assert len(scores) == len(hyps)

    results = []

    for token_int, score in zip(hyps, scores):
        # Change integer-ids to tokens
        token = self.converter.ids2tokens(token_int)

        if self.tokenizer is not None:
            text = self.tokenizer.tokens2text(token)
        else:
            text = None
        results.append((text, token, token_int, score))

    assert check_return_type(results)
    return results