def generate_nbest_list(lats: Fsa, num_paths: int) -> Nbest:
    '''Generate an n-best list from a lattice.

    Args:
      lats:
        The decoding lattice from the first pass after LM rescoring.
        lats is an FsaVec. It can be the return value of
        :func:`whole_lattice_rescoring`. It must have a `phones`
        attribute and must not have a `tokens` attribute.
      num_paths:
        Size of n for n-best list. CAUTION: After removing paths
        that represent the same token sequences, the number of paths
        in different sequences may not be equal.
    Return:
      Return an Nbest object. Note the returned FSAs don't have epsilon
      self-loops.

    Note:
      This function does not attach any attribute to `lats`, so it can
      be called more than once on the same lattice.
    '''
    assert len(lats.shape) == 3

    # CAUTION: We use `phones` instead of `tokens` here because
    # :func:`compile_HLG` uses `phones`
    #
    # Note: compile_HLG is from k2-fsa/snowfall
    assert hasattr(lats, 'phones')
    # Reject lattices that already carry `tokens`, so it is never
    # ambiguous which attribute the token IDs are read from.
    assert not hasattr(lats, 'tokens')

    # Read the token IDs from `phones` into a local variable instead of
    # assigning `lats.tokens = lats.phones`. That assignment would mutate
    # the caller's lattice and make a second call fail the assert above.
    tokens = lats.phones

    # First, extract `num_paths` paths for each sequence.
    # paths is a k2.RaggedTensor with axes [seq][path][arc_pos]
    paths = k2.random_paths(lats, num_paths=num_paths, use_double_scores=True)

    # token_seqs is a k2.RaggedTensor sharing the same shape as `paths`
    # but it contains token IDs. Note that it also contains 0s and -1s.
    # The last entry in each sublist is -1.
    # Its axes are [seq][path][token_id]
    token_seqs = k2.ragged.index(tokens, paths)

    # Remove epsilons (0s) and -1 from token_seqs
    token_seqs = token_seqs.remove_values_leq(0)

    # unique_token_seqs is still a k2.RaggedTensor with axes
    # [seq][path][token_id].
    # But the number of paths in each sequence may now be different.
    unique_token_seqs, _, _ = token_seqs.unique(
        need_num_repeats=False, need_new2old_indexes=False)

    seq_to_path_shape = unique_token_seqs.shape.get_layer(0)

    # Remove the seq axis.
    # Now unique_token_seqs has only two axes [path][token_id]
    unique_token_seqs = unique_token_seqs.remove_axis(0)

    token_fsas = k2.linear_fsa(unique_token_seqs)

    return Nbest(fsa=token_fsas, shape=seq_to_path_shape)
def nbest_am_lm_scores(
    lats: k2.Fsa,
    num_paths: int,
    device: str = "cuda",
    batch_size: int = 500,
):
    """Compute AM and LM scores for the unique paths sampled from a lattice.

    Compatible with both ctc_decoding or TLG decoding.

    `num_paths` paths per sequence are sampled with `k2.random_paths`.
    The paths are mapped to their `aux_labels` sequences, and duplicate
    sequences within a sequence are removed. Each remaining sequence is
    then scored with `compute_am_scores_and_lm_scores`.

    Args:
      lats:
        An FsaVec with `labels` and `aux_labels`. `aux_labels` may be
        either a torch.Tensor or a ragged tensor.
      num_paths:
        Number of paths sampled per sequence.
      device:
        Forwarded unchanged to `compute_am_scores_and_lm_scores`.
      batch_size:
        Forwarded unchanged to `compute_am_scores_and_lm_scores`.

    Returns:
      A tuple
      `(am_scores, lm_scores, token_ids, new2old, path_to_seq_map,
      seq_to_path_splits)`:
        - am_scores, lm_scores: as returned by
          `compute_am_scores_and_lm_scores` for the unique paths.
        - token_ids: a list with one entry per unique path. Each entry
          is that path's `lats.labels` after `remove_repeated_and_leq`.
        - new2old: maps each unique path to the index of the sampled
          path it came from (in the [path] axis, seq axis removed).
        - path_to_seq_map: path_to_seq_map[i] is the sequence that
          unique path i belongs to.
        - seq_to_path_splits: row splits of the seq -> path layer. They
          are used by callers to split the final tot_scores per sequence.
    """
    # paths has axes [seq][path][arc_pos]
    paths = k2.random_paths(lats, num_paths=num_paths, use_double_scores=True)
    if isinstance(lats.aux_labels, torch.Tensor):
        word_seqs = k2.ragged.index(lats.aux_labels.contiguous(), paths)
    else:
        # '_k2.RaggedInt' object has no attribute 'contiguous'
        word_seqs = lats.aux_labels.index(paths)
        # Indexing a ragged aux_labels leaves an extra (per-arc) axis;
        # drop the second-to-last axis so words of a path are
        # concatenated. Presumably yields [seq][path][word] -- confirm.
        word_seqs = word_seqs.remove_axis(word_seqs.num_axes - 2)

    # With ctc_decoding, word_seqs stores token_ids.
    # With TLG decoding, word_seqs stores word_ids.
    word_seqs = word_seqs.remove_values_leq(0)

    # Path multiplicities are never used below, so don't ask `unique`
    # to compute them.
    unique_word_seqs, _, new2old = word_seqs.unique(
        need_num_repeats=False, need_new2old_indexes=True
    )

    seq_to_path_shape = unique_word_seqs.shape.get_layer(0)
    path_to_seq_map = seq_to_path_shape.row_ids(1)
    # used to split final computed tot_scores
    seq_to_path_splits = seq_to_path_shape.row_splits(1)

    # Remove the seq axis: unique_word_seqs now has axes [path][word]
    unique_word_seqs = unique_word_seqs.remove_axis(0)

    word_fsas = k2.linear_fsa(unique_word_seqs)

    word_fsas_with_epsilon_loops = k2.add_epsilon_self_loops(word_fsas)

    am_scores, lm_scores = compute_am_scores_and_lm_scores(
        lats, word_fsas_with_epsilon_loops, path_to_seq_map, device, batch_size
    )

    token_seqs = k2.ragged.index(lats.labels.contiguous(), paths)
    token_seqs = token_seqs.remove_axis(0)

    # Reorder/select the sampled token sequences to match the unique paths.
    token_ids, _ = token_seqs.index(new2old, axis=0)
    token_ids = token_ids.tolist()
    # Now remove repeated tokens and 0s and -1s.
    token_ids = [remove_repeated_and_leq(tokens) for tokens in token_ids]
    return am_scores, lm_scores, token_ids, new2old, path_to_seq_map, seq_to_path_splits
def rescore_with_n_best_list(lats: k2.Fsa, G: k2.Fsa,
                             num_paths: int) -> k2.Fsa:
    '''Decode using n-best list with LM rescoring.

    `lats` is a decoding lattice, which has 3 axes. This function first
    extracts `num_paths` paths from `lats` for each sequence using
    `k2.random_paths`. The `am_scores` of these paths are computed.
    For each path, its `lm_scores` is computed using `G` (which is an LM).
    The final `tot_scores` is the sum of `am_scores` and `lm_scores`.
    The path with the greatest `tot_scores` within a sequence is used
    as the decoding output.

    Paths with the same word sequence within a sequence are scored only
    once. Their multiplicity is deliberately NOT used to weight
    `tot_scores`, because doing so led to a worse WER.

    Args:
      lats:
        An FsaVec. It can be the output of `k2.intersect_dense_pruned`.
        It must have `aux_labels` and `lm_scores` attributes.
      G:
        An FsaVec representing the language model (LM). Note that it
        is an FsaVec, but it contains only one Fsa. It must be on the
        same device as `lats` and must not have `aux_labels`.
      num_paths:
        It is the size `n` in `n-best` list.
    Returns:
      An FsaVec representing the best decoding path for each
      sequence in the lattice.
    '''
    device = lats.device

    assert len(lats.shape) == 3
    assert hasattr(lats, 'aux_labels')
    assert hasattr(lats, 'lm_scores')

    assert G.shape == (1, None, None)
    assert G.device == device
    assert not hasattr(G, 'aux_labels')

    # First, extract `num_paths` paths for each sequence.
    # paths is a k2.RaggedInt with axes [seq][path][arc_pos]
    paths = k2.random_paths(lats, num_paths=num_paths, use_double_scores=True)

    # word_seqs is a k2.RaggedInt sharing the same shape as `paths`
    # but it contains word IDs. Note that it also contains 0s and -1s.
    # The last entry in each sublist is -1.
    word_seqs = k2.index(lats.aux_labels, paths)

    # Remove epsilons and -1 from word_seqs
    word_seqs = k2.ragged.remove_values_leq(word_seqs, 0)

    # Remove repeated sequences to avoid redundant computation later.
    #
    # unique_word_seqs is still a k2.RaggedInt with 3 axes [seq][path][word]
    # except that there are no repeated paths with the same word_seq
    # within a seq.
    #
    # Path multiplicities (num_repeats) are not requested: weighting
    # tot_scores by them led to a worse WER, so computing them would be
    # wasted work.
    #
    # Since k2.ragged.unique_sequences will reorder paths within a seq,
    # `new2old` is a 1-D torch.Tensor mapping from the output path index
    # to the input path index.
    # new2old.numel() == unique_word_seqs.num_elements()
    unique_word_seqs, _, new2old = k2.ragged.unique_sequences(
        word_seqs, need_num_repeats=False, need_new2old_indexes=True)

    seq_to_path_shape = k2.ragged.get_layer(unique_word_seqs.shape(), 0)

    # path_to_seq_map is a 1-D torch.Tensor.
    # path_to_seq_map[i] is the seq to which the i-th path
    # belongs.
    path_to_seq_map = seq_to_path_shape.row_ids(1)

    # Remove the seq axis.
    # Now unique_word_seqs has only two axes [path][word]
    unique_word_seqs = k2.ragged.remove_axis(unique_word_seqs, 0)

    # word_fsas is an FsaVec with axes [path][state][arc]
    word_fsas = k2.linear_fsa(unique_word_seqs)

    word_fsas_with_epsilon_loops = k2.add_epsilon_self_loops(word_fsas)

    am_scores = compute_am_scores(lats, word_fsas_with_epsilon_loops,
                                  path_to_seq_map)

    # Now compute lm_scores.
    # G holds a single FSA, so every path is matched against G's FSA 0.
    b_to_a_map = torch.zeros_like(path_to_seq_map)
    lm_path_lats = _intersect_device(G,
                                     word_fsas_with_epsilon_loops,
                                     b_to_a_map=b_to_a_map,
                                     sorted_match_a=True)
    lm_path_lats = k2.top_sort(k2.connect(lm_path_lats.to('cpu'))).to(device)
    lm_scores = lm_path_lats.get_tot_scores(True, True)

    tot_scores = am_scores + lm_scores

    # TODO(fangjun): We may need to add `k2.RaggedDouble`
    ragged_tot_scores = k2.RaggedFloat(seq_to_path_shape,
                                       tot_scores.to(torch.float32))
    argmax_indexes = k2.ragged.argmax_per_sublist(ragged_tot_scores)

    # Use k2.index here since argmax_indexes' dtype is torch.int32
    best_path_indexes = k2.index(new2old, argmax_indexes)

    paths = k2.ragged.remove_axis(paths, 0)

    # best_paths is a k2.RaggedInt with 2 axes [path][arc_pos]
    best_paths = k2.index(paths, best_path_indexes)

    # labels is a k2.RaggedInt with 2 axes [path][phone_id]
    # Note that it contains -1s.
    labels = k2.index(lats.labels.contiguous(), best_paths)

    labels = k2.ragged.remove_values_eq(labels, -1)

    # lats.aux_labels is a k2.RaggedInt tensor with 2 axes, so
    # aux_labels is also a k2.RaggedInt with 2 axes
    aux_labels = k2.index(lats.aux_labels, best_paths.values())

    best_path_fsas = k2.linear_fsa(labels)
    best_path_fsas.aux_labels = aux_labels

    return best_path_fsas
def nbest_decoding(lats: k2.Fsa, num_paths: int):
    '''Pick, for each sequence, the best of `num_paths` sampled word sequences.

    (Ideas of this function are from Dan.) It implements something like
    CTC prefix beam search using n-best lists:

      1. sample `num_paths` paths per sequence from `lats`;
      2. turn every path into its word sequence and drop duplicates
         within each sequence;
      3. intersect each unique word sequence with the inverted lattice
         and compute its total score in the log semiring;
      4. use the word sequence with the highest score as the output.

    Args:
      lats:
        An FsaVec with `labels` (phone IDs) and `aux_labels` (word IDs).
        `aux_labels` may be a ragged tensor.
      num_paths:
        Number of paths sampled per sequence with `k2.random_paths`.
    Returns:
      An FsaVec holding one linear FSA per sequence. Its labels are the
      `lats.labels` along the chosen path with -1 removed, and its
      aux_labels come from `lats.aux_labels` along the same arcs.
    '''
    # Arc indexes of the sampled paths, axes [seq][path][arc_pos].
    sampled_paths = k2.random_paths(lats,
                                    num_paths=num_paths,
                                    use_double_scores=True)

    # Word IDs along every sampled path. The result has the same shape as
    # sampled_paths and still contains 0s plus a trailing -1 per path.
    # If lats.aux_labels is ragged, the binding applies `remove_axis=True`,
    # so the result keeps the 3 axes [seq][path][word].
    words = k2.index(lats.aux_labels, sampled_paths)
    words = k2.ragged.remove_values_leq(words, 0)  # drop epsilons and -1

    # Deduplicate word sequences inside each seq. unique_sequences may
    # reorder paths, so `new2old` maps every kept path back to its index
    # in the (seq-axis-removed) sampled paths.
    unique_words, _, new2old = k2.ragged.unique_sequences(
        words, need_num_repeats=False, need_new2old_indexes=True)

    # The seq -> path layer; row_ids(1)[i] is the seq owning unique path i.
    seq_to_path_shape = k2.ragged.get_layer(unique_words.shape(), 0)
    path_to_seq_map = seq_to_path_shape.row_ids(1)

    # Axes [path][word] -> one linear FSA per path, with epsilon self-loops.
    path_words = k2.ragged.remove_axis(unique_words, 0)
    word_fsas = k2.add_epsilon_self_loops(k2.linear_fsa(path_words))

    # Swap labels/aux_labels so matching happens on word IDs.
    # arc_sort is a no-op if the inverted lattice is already arc-sorted.
    inv_lats = k2.arc_sort(k2.invert(lats))

    path_lats = k2.intersect_device(inv_lats,
                                    word_fsas,
                                    b_to_a_map=path_to_seq_map,
                                    sorted_match_a=True)
    path_lats = k2.top_sort(k2.connect(path_lats.to('cpu')).to(lats.device))

    # RaggedFloat only supports float32, hence the cast.
    path_scores = path_lats.get_tot_scores(True, True).to(torch.float32)
    best_in_seq = k2.ragged.argmax_per_sublist(
        k2.RaggedFloat(seq_to_path_shape, path_scores))

    # best_in_seq indexes the deduplicated order; map it back to indexes
    # into the sampled paths. k2.index is used because the dtype is int32.
    chosen = k2.index(new2old, best_in_seq)
    flat_paths = k2.ragged.remove_axis(sampled_paths, 0)
    best_paths = k2.index(flat_paths, chosen)  # axes [path][arc_pos]

    # Phone labels along each best path, without the -1 entries.
    labels = k2.ragged.remove_values_eq(
        k2.index(lats.labels.contiguous(), best_paths), -1)
    aux_labels = k2.index(lats.aux_labels, best_paths.values())

    result = k2.linear_fsa(labels)
    result.aux_labels = aux_labels
    return result