def get_hierarchical_targets(ys: List[List[int]],
                             lexicon: k2.Fsa) -> List[Tensor]:
    """Convert word level transcripts into phone level transcripts.

    Args:
      ys:
        Word level transcripts, one list of word IDs per utterance.
      lexicon:
        Lexicon FSA. Its labels are words, while its aux_labels are phones.
        When it is None, no conversion is done.

    Returns:
      A list with one tensor of phone IDs per utterance. If `lexicon` is
      None, `ys` itself is returned unchanged (still a list of lists).
    """
    # Guard clause: without a lexicon there is nothing to map through.
    if lexicon is None:
        return ys

    # Identity ordering handed to get_texts(); the batch is not permuted here.
    batch_order = torch.tensor(range(len(ys)))

    # One linear FSA per utterance, gathered into a single FsaVec.
    word_fsas = k2.create_fsa_vec([k2.linear_fsa(words) for words in ys])

    # Intersect with the lexicon, then connect, arc-sort and remove epsilons
    # before picking the best path of every utterance.
    graph = k2.intersect(word_fsas, lexicon)
    graph = k2.connect(graph)
    graph = k2.arc_sort(graph)
    graph = k2.remove_epsilon(graph)
    graph = k2.shortest_path(graph, use_double_scores=True)

    phone_ids = get_texts(graph, batch_order)
    return [torch.tensor(ids) for ids in phone_ids]
def decode(dataloader: torch.utils.data.DataLoader,
           model: None,
           device: Union[str, torch.device],
           ctc_topo: None,
           numericalizer=None,
           num_paths=-1,
           output_beam_size: float = 8):
    """Decode every batch of `dataloader` against a CTC topology.

    Args:
      dataloader:
        Yields dict batches containing `inputs` and `supervisions`.
        `dataloader.dataset.cuts` must support `len()`; it is only used for
        progress logging.
      model:
        Called as `model(feature, supervisions)`; it must return three
        values, of which only the first (the network output) is used.
        (The signature annotates it as `None`; it is kept as is so the
        interface stays unchanged.)
      device:
        Device the features are moved to before the forward pass.
      ctc_topo:
        Decoding graph handed to `k2.intersect_dense_pruned`.
      numericalizer:
        Must provide `tokens_list` (token id -> piece) and
        `tokenizer.DecodePieces`. The default `None` only works for an
        empty dataloader.
      num_paths:
        Unused by this function.
      output_beam_size:
        Passed as the fourth positional argument (output beam) of
        `k2.intersect_dense_pruned`.

    Returns:
      A list of `(ref_words, hyp_words)` pairs, one per utterance, where
      both elements are lists of strings.
    """
    tot_num_cuts = len(dataloader.dataset.cuts)
    num_cuts = 0
    results = []
    for batch_idx, batch in enumerate(dataloader):
        assert isinstance(batch, dict), type(batch)
        feature = batch['inputs']
        supervisions = batch['supervisions']

        # Frame offsets/lengths are reduced with (x - 1) // 2 applied twice;
        # presumably this mirrors the model's frame subsampling -- confirm
        # against the model definition.
        supervision_segments = torch.stack(
            (supervisions['sequence_idx'],
             (((supervisions['start_frame'] - 1) // 2 - 1) // 2),
             (((supervisions['num_frames'] - 1) // 2 - 1) // 2)),
            1).to(torch.int32)
        supervision_segments = torch.clamp(supervision_segments, min=0)

        # Sort segments by decreasing duration; `indices` is forwarded to
        # get_texts() together with the best paths.
        indices = torch.argsort(supervision_segments[:, 2], descending=True)
        supervision_segments = supervision_segments[indices]
        texts = supervisions['text']

        assert feature.ndim == 3
        feature = feature.to(device)
        # at entry, feature is [N, T, C]
        feature = feature.permute(0, 2, 1)
        # now feature is [N, C, T]

        # Decoding never back-propagates, so the forward pass is run under
        # no_grad as well to avoid recording autograd state.
        with torch.no_grad():
            nnet_output, _, _ = model(feature, supervisions)
            nnet_output = nnet_output.permute(0, 2, 1)

            # TODO(Liyong Guo): Tune this bias
            # blank_bias = 0.0
            # nnet_output[:, :, 0] += blank_bias

            dense_fsa_vec = k2.DenseFsaVec(nnet_output, supervision_segments)
            lattices = k2.intersect_dense_pruned(ctc_topo, dense_fsa_vec, 20.0,
                                                 output_beam_size, 30, 10000)
            best_paths = k2.shortest_path(lattices, use_double_scores=True)

        hyps = get_texts(best_paths, indices)
        assert len(hyps) == len(texts)

        for ref_text, hyp_ids in zip(texts, hyps):
            pieces = [numericalizer.tokens_list[token_id]
                      for token_id in hyp_ids]
            hyp_words = numericalizer.tokenizer.DecodePieces(pieces).split(' ')
            ref_words = ref_text.split(' ')
            results.append((ref_words, hyp_words))

        # NOTE: num_cuts is updated after logging, so the logged count does
        # not include the current batch.
        if batch_idx % 10 == 0:
            logging.info(
                'batch {}, cuts processed until now is {}/{} ({:.6f}%)'.format(
                    batch_idx, num_cuts, tot_num_cuts,
                    float(num_cuts) / tot_num_cuts * 100))
        num_cuts += len(texts)

    return results
def decode(
        self,
        cuts: Union[AnyCut, CutSet]) -> List[Tuple[List[str], List[str]]]:
    """
    Decode `cuts` against the HLG graph (n-gram language model) and return,
    for every utterance, a `(ref_words, hyp_words)` pair.
    Rescoring is not supported at this time.
    """
    # A single cut is wrapped so that the rest of the code sees a CutSet.
    if isinstance(cuts, (Cut, MixedCut)):
        cuts = CutSet.from_cuts([cuts])

    # Hacky way to get a batch quickly: build a dataset and index it with
    # all cut ids at once. We may need to improve on this.
    dataset = K2SpeechRecognitionDataset(
        cuts,
        input_strategy=OnTheFlyFeatures(self.extractor),
        check_inputs=False)
    batch = dataset[list(cuts.ids)]

    # (B, T, F) -> (B, F, T)
    features = batch['inputs'].permute(0, 2, 1)
    features = features.to(self.device)
    supervision_segments, texts = encode_supervisions(batch['supervisions'])

    # Acoustic model forward pass; only the first of three outputs is used.
    posteriors, _, _ = self.model(features)
    # (B, F, T) -> (B, T, F)
    posteriors = posteriors.permute(0, 2, 1)

    # Wrap the posteriors into a k2 "dense FSA" (the PPG as a dense graph).
    dense_fsa_vec = k2.DenseFsaVec(posteriors, supervision_segments)

    # Intersect HLG with the PPG using the default pruning/beam search
    # params from snowfall; the result is a batch of lattices.
    lattices = k2.intersect_dense_pruned(self.HLG, dense_fsa_vec, 20.0, 8, 30,
                                         10000)

    # Best path of each lattice, converted to word ids by get_texts().
    best_paths = k2.shortest_path(lattices, use_double_scores=True)
    hyps = get_texts(best_paths, torch.arange(len(texts)))

    # Read out the words of each best path next to the reference words.
    word_results = []
    for utt_idx, ref_text in enumerate(texts):
        hyp_words = [self.lexicon.words.get(word_id)
                     for word_id in hyps[utt_idx]]
        ref_words = ref_text.split(' ')
        word_results.append((ref_words, hyp_words))
    return word_results
def get_hierarchical_targets(ys: List[List[int]],
                             lexicon: k2.Fsa) -> List[Tensor]:
    """Convert word level transcripts into phone level transcripts.

    Args:
      ys:
        Word level transcripts, one list of word IDs per utterance.
      lexicon:
        Lexicon FSA. Its labels are words, while its aux_labels are phones.
        The transcript FSAs are created on the lexicon's device. When it is
        None, no conversion is done.

    Returns:
      A list with one tensor of phone IDs per utterance. If `lexicon` is
      None, `ys` itself is returned unchanged (still a list of lists).
    """
    # Guard clause: without a lexicon there is nothing to map through.
    if lexicon is None:
        return ys

    # Identity ordering handed to get_texts(); the batch is not permuted here.
    batch_order = torch.tensor(range(len(ys)))

    # Build the per-utterance linear FSAs on the same device as the lexicon.
    lexicon_device = lexicon.device
    word_fsas = k2.create_fsa_vec(
        [k2.linear_fsa(words, device=lexicon_device) for words in ys])
    word_fsas_with_loops = k2.add_epsilon_self_loops(word_fsas)

    # The intersection is deliberately not inverted: the phone IDs to be
    # returned are the `aux_labels` of the resulting graph.
    graph = k2.intersect(lexicon,
                         word_fsas_with_loops,
                         treat_epsilons_specially=False)
    graph = k2.remove_epsilon(graph)
    graph = k2.top_sort(graph)
    graph = k2.shortest_path(graph, use_double_scores=True)

    phone_ids = get_texts(graph, batch_order)
    return [torch.tensor(ids) for ids in phone_ids]
def decode(
    dataloader: torch.utils.data.DataLoader,
    model: AcousticModel,
    device: Union[str, torch.device],
    HLG: Fsa,
    symbols: SymbolTable,
):
    """Decode every batch of `dataloader` against the `HLG` graph.

    Args:
      dataloader:
        Yields batches with `inputs` and `supervisions` entries. It does not
        need to support `len()`; the length is only used for logging.
      model:
        Called as `model(feature)`; must expose `subsampling_factor`.
      device:
        Device the features are moved to before the forward pass.
      HLG:
        Decoding graph; must live on the same device as the network output.
      symbols:
        Its `get(id)` maps the ids found on the best paths to words.

    Returns:
      A list of `(ref_words, hyp_words)` pairs, one per utterance.
    """
    # Some dataloaders have no length; fall back to None for logging.
    try:
        num_batches = len(dataloader)
    except TypeError:
        num_batches = None

    num_cuts = 0
    results = []  # a list of pairs (ref_words, hyp_words)

    for batch_idx, batch in enumerate(dataloader):
        feature = batch["inputs"]
        supervisions = batch["supervisions"]

        # Express start/duration of each segment in subsampled frames.
        sequence_idx = supervisions["sequence_idx"]
        factor = model.subsampling_factor
        start_frame = torch.floor_divide(supervisions["start_frame"], factor)
        num_frames = torch.floor_divide(supervisions["num_frames"], factor)
        supervision_segments = torch.stack(
            (sequence_idx, start_frame, num_frames), 1).to(torch.int32)

        # Sort segments by decreasing duration; `indices` is forwarded to
        # get_texts() together with the best paths.
        indices = torch.argsort(supervision_segments[:, 2], descending=True)
        supervision_segments = supervision_segments[indices]
        texts = supervisions["text"]

        assert feature.ndim == 3
        feature = feature.to(device)
        # at entry, feature is [N, T, C]
        feature = feature.permute(0, 2, 1)
        # now feature is [N, C, T]

        with torch.no_grad():
            nnet_output = model(feature)
        # nnet_output is [N, C, T]
        nnet_output = nnet_output.permute(0, 2, 1)
        # now nnet_output is [N, T, C]

        # Constant offset added to index 0 of the last axis (the blank).
        blank_bias = -3.0
        nnet_output[:, :, 0] += blank_bias

        dense_fsa_vec = k2.DenseFsaVec(nnet_output, supervision_segments)
        assert (
            HLG.device == nnet_output.device
        ), f"Check failed: HLG.device ({HLG.device}) == nnet_output.device ({nnet_output.device})"

        # TODO(haowen): with a small `beam`, we may get empty `target_graph`,
        # thus `tot_scores` will be `inf`. Definitely we need to handle this
        # later.
        # Unpruned alternative: k2.intersect_dense(HLG, dense_fsa_vec, 10.0)
        lattices = k2.intersect_dense_pruned(HLG, dense_fsa_vec, 20.0, 7.0, 30,
                                             10000)
        best_paths = k2.shortest_path(lattices, use_double_scores=True)
        assert best_paths.shape[0] == len(texts)

        hyps = get_texts(best_paths, indices)
        assert len(hyps) == len(texts)

        for utt_idx, ref_text in enumerate(texts):
            hyp_words = [symbols.get(word_id) for word_id in hyps[utt_idx]]
            ref_words = ref_text.split(" ")
            results.append((ref_words, hyp_words))

        # NOTE: num_cuts is updated after logging, so the logged count does
        # not include the current batch.
        if batch_idx % 10 == 0:
            if num_batches is None:
                batch_str = f"{batch_idx}"
            else:
                batch_str = f"{batch_idx}/{num_batches}"
            logging.info(
                f"batch {batch_str}, number of cuts processed until now is {num_cuts}"
            )
        num_cuts += len(texts)

    return results
def decode_one_batch(batch: Dict[str, Any],
                     model: AcousticModel,
                     HLG: k2.Fsa,
                     output_beam_size: float,
                     num_paths: int,
                     use_whole_lattice: bool,
                     G: Optional[k2.Fsa] = None) -> Dict[str, List[List[int]]]:
    '''Decode a single batch and return the hypotheses keyed by setting.

    The returned dict is laid out as follows:

        - key: Names the decoding setting. Without rescoring it is
               `no_rescore`, or `no_rescore-<num_paths>` when n-best
               decoding is used (`num_paths > 1`). With LM rescoring it is
               `lm_scale_xxx`, where `xxx` is the value of `lm_scale`,
               e.g. `lm_scale_0.7`.
        - value: The decoding result. `len(value)` equals the batch size
                 and `value[i]` is the result for the i-th utterance of the
                 given batch.

    Args:
      batch:
        The return value from iterating
        `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
        for the format of the `batch`.
      model:
        The neural network model. It is called as
        `model(feature, supervisions)` and only its first output is used.
      HLG:
        The decoding graph. Features are moved to its device.
      output_beam_size:
        Size of the beam for pruning.
      num_paths:
        The `n` of n-best list decoding.
      use_whole_lattice:
        If True, `G` must not be None and the whole lattice is used for LM
        rescoring. If False and `G` is not None, `num_paths` must be
        positive and an n-best list is used for LM rescoring.
      G:
        The LM. If it is None, no rescoring is done. Otherwise LM rescoring
        is used, either on an n-best list or on the whole lattice, as
        selected by `use_whole_lattice`.

    Returns:
      The decoding result in the dict format described above.
    '''
    device = HLG.device

    feature = batch['inputs']
    assert feature.ndim == 3
    feature = feature.to(device)
    # at entry, feature is [N, T, C]
    feature = feature.permute(0, 2, 1)
    # now feature is [N, C, T]

    supervisions = batch['supervisions']
    nnet_output, _, _ = model(feature, supervisions)
    # nnet_output is [N, C, T]
    nnet_output = nnet_output.permute(0, 2, 1)
    # now nnet_output is [N, T, C]

    # Frame offsets/lengths are reduced with (x - 1) // 2 applied twice;
    # presumably this mirrors the model's frame subsampling -- confirm
    # against the model definition.
    sequence_idx = supervisions['sequence_idx']
    start_frame = ((supervisions['start_frame'] - 1) // 2 - 1) // 2
    num_frames = ((supervisions['num_frames'] - 1) // 2 - 1) // 2
    supervision_segments = torch.stack(
        (sequence_idx, start_frame, num_frames), 1).to(torch.int32)
    supervision_segments = torch.clamp(supervision_segments, min=0)

    # Sort segments by decreasing duration; `indices` is forwarded to
    # get_texts() together with the best paths.
    indices = torch.argsort(supervision_segments[:, 2], descending=True)
    supervision_segments = supervision_segments[indices]

    dense_fsa_vec = k2.DenseFsaVec(nnet_output, supervision_segments)
    lattices = k2.intersect_dense_pruned(HLG, dense_fsa_vec, 20.0,
                                         output_beam_size, 30, 10000)

    # No LM given: plain 1-best or n-best decoding, a single dict entry.
    if G is None:
        if num_paths > 1:
            best_paths = nbest_decoding(lattices, num_paths)
            key = f'no_rescore-{num_paths}'
        else:
            key = 'no_rescore'
            best_paths = k2.shortest_path(lattices, use_double_scores=True)
        return {key: get_texts(best_paths, indices)}

    lm_scale_list = [
        0.8, 0.9, 1.0, 1.1, 1.2, 1.3,
        1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2.0,
    ]

    # best_paths_dict maps `lm_scale_xxx` (e.g. lm_scale_1.2) to the best
    # paths obtained with that LM scale.
    if use_whole_lattice:
        best_paths_dict = rescore_with_whole_lattice(lattices, G,
                                                     lm_scale_list)
    else:
        best_paths_dict = rescore_with_n_best_list(lattices, G, num_paths,
                                                   lm_scale_list)

    return {
        lm_scale_str: get_texts(scaled_best_paths, indices)
        for lm_scale_str, scaled_best_paths in best_paths_dict.items()
    }