def test_aux_ragged(self):
    """Invert an acceptor whose aux_labels are ragged and check the result.

    The expected arrays describe the inverted FSA: arcs with several aux
    labels are expanded, and tensor attributes of the extra arcs are 0.
    """
    # https://git.io/JqNiR
    fsa_text = '''
        0 1 1 0.1
        0 2 2 0.2
        1 3 3 0.3
        2 3 4 0.6
        3 4 -1 0.7
        4
    '''
    # https://git.io/JqNiw
    expected_attrs = (
        ('tensor_attr1', [1, 2, 0, 0, 3, 4, 5]),
        ('aux_labels', [1, 2, 0, 0, 3, 4, -1]),
        ('labels', [2, 3, 3, 4, 0, 5, -1]),
    )
    for device in self.devices:
        fsa = k2.Fsa.from_str(fsa_text).to(device)
        fsa.aux_labels = k2.RaggedInt('[[2 3] [3 4] [] [5] [-1]]').to(device)
        fsa.tensor_attr1 = torch.tensor([1, 2, 3, 4, 5]).to(device)

        inverted = k2.invert(fsa)
        for attr_name, values in expected_attrs:
            wanted = torch.tensor(values, device=device)
            assert torch.all(torch.eq(getattr(inverted, attr_name), wanted))
def compute_am_scores(lats: k2.Fsa, word_fsas_with_epsilon_loops: k2.Fsa,
                      path_to_seq_map: torch.Tensor) -> torch.Tensor:
    '''Compute the AM score of every path in an n-best list.

    Args:
      lats:
        An FsaVec, e.g. the output of `k2.intersect_dense_pruned`. It must
        carry an `lm_scores` attribute.
      word_fsas_with_epsilon_loops:
        An FsaVec holding the n-best list (one Fsa per path), already
        processed by `k2.add_epsilon_self_loops`.
      path_to_seq_map:
        A 1-D torch.Tensor of dtype torch.int32; entry i is the index of the
        sequence (Fsa in `lats`) that the i-th path belongs to. Its numel()
        equals the number of Fsas in `word_fsas_with_epsilon_loops`.

    Returns:
      A 1-D torch.Tensor with one AM score per path, i.e.
      `ans.numel() == word_fsas_with_epsilon_loops.shape[0]`.
    '''
    device = lats.device
    assert len(lats.shape) == 3
    assert hasattr(lats, 'lm_scores')

    # k2.compose() does not accept b_to_a_map, so to avoid replicating
    # `lats` we intersect instead. That needs word IDs on the `labels`
    # side, hence the inversion. The resulting aux_labels (phone IDs) are
    # not used below, so they are dropped.
    inverted_lats = k2.invert(lats)
    del inverted_lats.aux_labels
    inverted_lats = k2.arc_sort(inverted_lats)

    am_path_lats = _intersect_device(inverted_lats,
                                     word_fsas_with_epsilon_loops,
                                     b_to_a_map=path_to_seq_map,
                                     sorted_match_a=True)

    # connect/top_sort are run on CPU (see the original note: they were
    # CPU-only at the time), then the result goes back to `device`.
    on_cpu = am_path_lats.to('cpu')
    am_path_lats = k2.top_sort(k2.connect(on_cpu)).to(device)

    # Arc scores are am_scores + lm_scores; keep only the AM part.
    am_path_lats.scores = am_path_lats.scores - am_path_lats.lm_scores
    return am_path_lats.get_tot_scores(use_double_scores=True,
                                       log_semiring=True)
def test_aux_as_tensor(self):
    """Invert a transducer whose aux_labels are a plain tensor.

    The original version only printed the result. It now also checks that
    the labels and aux_labels of the inverted FSA are the swapped arrays of
    the input.
    """
    s = '''
        0 1 1 1 0
        0 1 0 2 0
        0 3 2 3 0
        1 2 3 4 0
        1 3 4 5 0
        2 1 5 6 0
        2 5 -1 -1 0
        3 1 6 7 0
        4 5 -1 -1 0
        5
    '''
    fsa = k2.Fsa.from_str(s, num_aux_labels=1)
    assert fsa.device.type == 'cpu'
    dest = k2.invert(fsa)
    print(dest)
    # With tensor (non-ragged) aux_labels, inversion is expected to swap
    # labels and aux_labels arc by arc, without adding arcs.
    assert torch.equal(dest.labels, fsa.aux_labels)
    assert torch.equal(dest.aux_labels, fsa.labels)
def test_composition_equivalence(self):
    """Check that k2.replace_fsa matches an intersect/invert construction.

    Replacing label 1 in `index` with the FSAs in `src` should be
    equivalent, in the log semiring, to intersecting `index` with the
    transducer built by `_construct_f(src)` and projecting back.
    """
    # Call order of _generate_fsa_vec() is kept (index first, then src).
    index = k2.arc_sort(k2.connect(k2.remove_epsilon(_generate_fsa_vec())))
    src = _generate_fsa_vec()

    replaced = k2.top_sort(k2.replace_fsa(src, index, 1))

    f_fsa = k2.arc_sort(_construct_f(src))
    composed = k2.intersect(index, f_fsa, treat_epsilons_specially=True)
    composed = k2.top_sort(k2.invert(composed))
    del composed.aux_labels

    assert k2.is_rand_equivalent(replaced, composed,
                                 log_semiring=True, delta=1e-3)
def _construct_f(fsa_vec: k2.Fsa) -> k2.Fsa:
    """Build the inverted "F" transducer used by the replace_fsa test.

    Steps:
      1. Union the FSAs in `fsa_vec`.
      2. Give the first `num_fsa` arcs aux labels 1..num_fsa and every
         other arc aux label 0.
      3. Rewrite the text form so that arcs entering the union's final
         state loop back to state 0, with label -1 replaced by 0.
      4. Add an arc `0 -> final` with label -1.
      5. Return the inverse of the result.

    Args:
      fsa_vec: An FsaVec.

    Returns:
      The inverted transducer.
    """
    num_fsa = fsa_vec.shape[0]
    union = k2.union(fsa_vec)

    # Build aux_labels as int32 from the start. The original created a
    # float32 zeros tensor and wrote int32 values into it, relying on a
    # later conversion. The code assumes the first `num_fsa` arcs are the
    # arcs out of the union's start state.
    aux_labels = torch.zeros(union.num_arcs, dtype=torch.int32)
    aux_labels[:num_fsa] = torch.arange(1, num_fsa + 1, dtype=torch.int32)
    union.aux_labels = aux_labels

    union_str = k2.to_str_simple(union)
    final_state = union.shape[0] - 1

    new_lines = ['0 {} -1 0 0'.format(final_state)]
    for line in union_str.strip().split('\n'):
        tokens = line.strip().split(' ')
        if len(tokens) != 5:
            # Not an arc line (e.g. the trailing final-state line); drop it.
            continue
        if int(tokens[1]) == final_state:
            tokens[1] = '0'
        if int(tokens[2]) == -1:
            tokens[2] = '0'
        new_lines.append(' '.join(tokens))

    new_fsa = k2.Fsa.from_str('\n'.join(new_lines), num_aux_labels=1)
    return k2.invert(new_fsa)
def test_aux_as_ragged(self):
    """Invert an acceptor after attaching ragged aux_labels built by hand."""
    s = '''
        0 1 1 0
        0 1 0 0
        0 3 2 0
        1 2 3 0
        1 3 4 0
        2 1 5 0
        2 5 -1 0
        3 1 6 0
        4 5 -1 0
        5
    '''
    fsa = k2.Fsa.from_str(s)
    assert fsa.device.type == 'cpu'

    # One sublist per arc (9 arcs); arcs with label -1 get aux label -1.
    values = torch.tensor([1, 2, 3, 5, 6, 7, 8, -1, 9, 10, -1],
                          dtype=torch.int32)
    row_splits = torch.tensor([0, 2, 3, 3, 6, 6, 7, 8, 10, 11],
                              dtype=torch.int32)
    shape = k2.ragged.create_ragged_shape2(row_splits, None, values.numel())
    fsa.aux_labels = k2.RaggedInt(shape, values)

    print(k2.invert(fsa))  # will print aux_labels as well
def rescore_with_whole_lattice(lats: k2.Fsa,
                               G_with_epsilon_loops: k2.Fsa) -> k2.Fsa:
    '''Rescore a lattice with a whole-lattice LM and return the best paths.

    Args:
      lats:
        An FsaVec, e.g. the output of `k2.intersect_dense_pruned`. It must
        have an `lm_scores` attribute. NOTE: `lats.scores` is overwritten
        in place with the AM-only scores (scores - lm_scores).
      G_with_epsilon_loops:
        An FsaVec containing exactly one Fsa: the language model, with
        epsilon self-loops.

    Returns:
      The result of `k2.shortest_path` on the rescored lattice.
    '''
    assert len(lats.shape) == 3
    assert hasattr(lats, 'lm_scores')
    assert G_with_epsilon_loops.shape == (1, None, None)

    device = lats.device
    # Keep only the AM part of the scores (modifies the caller's lats).
    lats.scores = lats.scores - lats.lm_scores

    num_seqs = lats.shape[0]
    # Every sequence is intersected with the single Fsa (index 0) in G.
    b_to_a_map = torch.zeros(num_seqs, device=device, dtype=torch.int32)

    def intersect_with_g(lattice_with_loops):
        return k2.intersect_device(G_with_epsilon_loops,
                                   lattice_with_loops,
                                   b_to_a_map,
                                   sorted_match_a=True)

    # After inversion, labels are word IDs and aux_labels are phone IDs
    # (a k2.RaggedInt).
    inverted = k2.invert(lats)
    inverted_looped = k2.add_epsilon_self_loops(inverted)
    try:
        rescored = intersect_with_g(inverted_looped)
    except RuntimeError as e:
        print(f'Caught exception:\n{e}\n')
        print(f'Number of FSAs: {inverted.shape[0]}')
        print('num_arcs before pruning: ',
              inverted_looped.arcs.num_elements())
        # NOTE(fangjun): The threshold 0.001 is an arbitrary choice to
        # avoid OOM; it may need fine tuning. There is only one retry.
        inverted = k2.prune_on_arc_post(inverted, 0.001, True)
        inverted_looped = k2.add_epsilon_self_loops(inverted)
        print('num_arcs after pruning: ',
              inverted_looped.arcs.num_elements())
        rescored = intersect_with_g(inverted_looped)

    rescored = k2.top_sort(k2.connect(rescored.to('cpu'))).to(device)

    # Invert back so phone IDs are labels and word IDs are aux_labels, then
    # drop the epsilon self-loops before the best-path search.
    phone_lats = k2.remove_epsilon_self_loops(k2.invert(rescored))
    return k2.shortest_path(phone_lats, use_double_scores=True)
def rescore_with_whole_lattice(lats: k2.Fsa,
                               G_with_epsilon_loops: k2.Fsa,
                               lm_scale_list: List[float]
                               ) -> Dict[str, k2.Fsa]:
    '''Use whole lattice to rescore, once per lm_scale.

    Args:
      lats:
        An FsaVec. It can be the output of `k2.intersect_dense_pruned`. It
        must have an `lm_scores` attribute. NOTE: this function modifies
        `lats` in place: `lats.scores` becomes the AM-only scores and
        `lats.lm_scores` is deleted.
      G_with_epsilon_loops:
        An FsaVec representing the language model (LM). Note that it is an
        FsaVec, but it contains only one Fsa. The rescored lattice reads
        `lm_scores` after intersecting with it; presumably those come from
        this G, so G is expected to carry `lm_scores` (TODO confirm).
      lm_scale_list:
        A list containing lm_scale values.

    Returns:
      A dict with one entry per value in `lm_scale_list`. The key is the
      string f'lm_scale_{lm_scale}' and the value is an FsaVec with the
      best decoding path for each sequence in the lattice.
    '''
    assert len(lats.shape) == 3
    assert hasattr(lats, 'lm_scores')
    assert G_with_epsilon_loops.shape == (1, None, None)

    device = lats.device
    # Keep only the AM part of the scores.
    lats.scores = lats.scores - lats.lm_scores
    # We will use lm_scores from G, so remove lats.lm_scores here.
    del lats.lm_scores
    assert not hasattr(lats, 'lm_scores')

    # inverted_lats has word IDs as labels. Its aux_labels are phone IDs,
    # which is a ragged tensor k2.RaggedInt.
    inverted_lats = k2.invert(lats)
    num_seqs = lats.shape[0]
    # Every sequence is intersected with the single Fsa (index 0) in G.
    b_to_a_map = torch.zeros(num_seqs, device=device, dtype=torch.int32)
    try:
        rescoring_lats = k2.intersect_device(G_with_epsilon_loops,
                                             inverted_lats,
                                             b_to_a_map,
                                             sorted_match_a=True)
    except RuntimeError as e:
        print(f'Caught exception:\n{e}\n')
        print(f'Number of FSAs: {inverted_lats.shape[0]}')
        print('num_arcs before pruning: ', inverted_lats.arcs.num_elements())

        # NOTE(fangjun): The choice of the threshold 0.001 is arbitrary here
        # to avoid OOM. We may need to fine tune it. Only one retry is made.
        inverted_lats = k2.prune_on_arc_post(inverted_lats, 0.001, True)
        print('num_arcs after pruning: ', inverted_lats.arcs.num_elements())

        rescoring_lats = k2.intersect_device(G_with_epsilon_loops,
                                             inverted_lats,
                                             b_to_a_map,
                                             sorted_match_a=True)

    # connect runs on CPU; top_sort runs after moving back to `device`.
    rescoring_lats = k2.top_sort(
        k2.connect(rescoring_lats.to('cpu')).to(device))

    # inv_lats has phone IDs as labels and word IDs as aux_labels.
    inv_lats = k2.invert(rescoring_lats)

    ans = {}
    #
    # The following implements
    # scores = (scores - lm_scores)/lm_scale + lm_scores
    #        = scores/lm_scale + lm_scores*(1 - 1/lm_scale)
    #
    saved_scores = inv_lats.scores.clone()
    for lm_scale in lm_scale_list:
        am_scores = saved_scores - inv_lats.lm_scores
        am_scores /= lm_scale
        inv_lats.scores = am_scores + inv_lats.lm_scores

        best_paths = k2.shortest_path(inv_lats, use_double_scores=True)
        ans[f'lm_scale_{lm_scale}'] = best_paths
    return ans
import k2

# Transducer in k2 text format; arc lines are:
#   src_state dest_state label aux_label score
# and the last line is the final state.
s = '''
0 1 2 10 0.1
1 2 -1 -1 0.2
2
'''
# The text has an aux_label column, so declare it explicitly rather than
# relying on from_str to infer a transducer from the column count.
fsa = k2.Fsa.from_str(s, num_aux_labels=1)
inverted_fsa = k2.invert(fsa)
fsa.draw('before_invert.svg', title='before invert')
inverted_fsa.draw('after_invert.svg', title='after invert')
def nbest_decoding(lats: k2.Fsa, num_paths: int): ''' (Ideas of this function are from Dan) It implements something like CTC prefix beam search using n-best lists The basic idea is to first extra n-best paths from the given lattice, build a word seqs from these paths, and compute the total scores of these sequences in the log-semiring. The one with the max score is used as the decoding output. ''' # First, extract `num_paths` paths for each sequence. # paths is a k2.RaggedInt with axes [seq][path][arc_pos] paths = k2.random_paths(lats, num_paths=num_paths, use_double_scores=True) # word_seqs is a k2.RaggedInt sharing the same shape as `paths` # but it contains word IDs. Note that it also contains 0s and -1s. # The last entry in each sublist is -1. word_seqs = k2.index(lats.aux_labels, paths) # Note: the above operation supports also the case when # lats.aux_labels is a ragged tensor. In that case, # `remove_axis=True` is used inside the pybind11 binding code, # so the resulting `word_seqs` still has 3 axes, like `paths`. # The 3 axes are [seq][path][word] # Remove epsilons and -1 from word_seqs word_seqs = k2.ragged.remove_values_leq(word_seqs, 0) # Remove repeated sequences to avoid redundant computation later. # # Since k2.ragged.unique_sequences will reorder paths within a seq, # `new2old` is a 1-D torch.Tensor mapping from the output path index # to the input path index. # new2old.numel() == unique_word_seqs.num_elements() unique_word_seqs, _, new2old = k2.ragged.unique_sequences( word_seqs, need_num_repeats=False, need_new2old_indexes=True) # Note: unique_word_seqs still has the same axes as word_seqs seq_to_path_shape = k2.ragged.get_layer(unique_word_seqs.shape(), 0) # path_to_seq_map is a 1-D torch.Tensor. # path_to_seq_map[i] is the seq to which the i-th path # belongs. path_to_seq_map = seq_to_path_shape.row_ids(1) # Remove the seq axis. 
# Now unique_word_seqs has only two axes [path][word] unique_word_seqs = k2.ragged.remove_axis(unique_word_seqs, 0) # word_fsas is an FsaVec with axes [path][state][arc] word_fsas = k2.linear_fsa(unique_word_seqs) word_fsas_with_epsilon_loops = k2.add_epsilon_self_loops(word_fsas) # lats has phone IDs as labels and word IDs as aux_labels. # inv_lats has word IDs as labels and phone IDs as aux_labels inv_lats = k2.invert(lats) inv_lats = k2.arc_sort(inv_lats) # no-op if inv_lats is already arc-sorted path_lats = k2.intersect_device(inv_lats, word_fsas_with_epsilon_loops, b_to_a_map=path_to_seq_map, sorted_match_a=True) # path_lats has word IDs as labels and phone IDs as aux_labels path_lats = k2.top_sort(k2.connect(path_lats.to('cpu')).to(lats.device)) tot_scores = path_lats.get_tot_scores(True, True) # RaggedFloat currently supports float32 only. # We may bind Ragged<double> as RaggedDouble if needed. ragged_tot_scores = k2.RaggedFloat(seq_to_path_shape, tot_scores.to(torch.float32)) argmax_indexes = k2.ragged.argmax_per_sublist(ragged_tot_scores) # Since we invoked `k2.ragged.unique_sequences`, which reorders # the index from `paths`, we use `new2old` # here to convert argmax_indexes to the indexes into `paths`. # # Use k2.index here since argmax_indexes' dtype is torch.int32 best_path_indexes = k2.index(new2old, argmax_indexes) paths_2axes = k2.ragged.remove_axis(paths, 0) # best_paths is a k2.RaggedInt with 2 axes [path][arc_pos] best_paths = k2.index(paths_2axes, best_path_indexes) # labels is a k2.RaggedInt with 2 axes [path][phone_id] # Note that it contains -1s. labels = k2.index(lats.labels.contiguous(), best_paths) labels = k2.ragged.remove_values_eq(labels, -1) # lats.aux_labels is a k2.RaggedInt tensor with 2 axes, so # aux_labels is also a k2.RaggedInt with 2 axes aux_labels = k2.index(lats.aux_labels, best_paths.values()) best_path_fsas = k2.linear_fsa(labels) best_path_fsas.aux_labels = aux_labels return best_path_fsas
def whole_lattice_rescoring(lats: Fsa, G_with_epsilon_loops: Fsa) -> Fsa:
    '''Rescore the 1st pass lattice with an LM.

    In general, the G in HLG used to obtain `lats` is a 3-gram LM.
    This function replaces the 3-gram LM in `lats` with a 4-gram LM.

    Args:
      lats:
        The decoding lattice from the 1st pass. We assume it is the result
        of intersecting HLG with the network output. It must have an
        `lm_scores` attribute. NOTE: it is modified in place: its scores
        become AM-only and `lm_scores` is deleted.
      G_with_epsilon_loops:
        An LM. It is usually a 4-gram LM with epsilon self-loops.
        It should be arc sorted.

    Returns:
      Return a new lattice rescored with a given G.

    Raises:
      RuntimeError: if `k2.intersect_device` keeps failing and pruning no
        longer removes any arcs. Previously this case retried forever.
    '''
    assert len(lats.shape) == 3, f'{lats.shape}'
    assert hasattr(lats, 'lm_scores')
    assert G_with_epsilon_loops.shape == (1, None, None), \
        f'{G_with_epsilon_loops.shape}'

    device = lats.device
    lats.scores = lats.scores - lats.lm_scores
    # Now lats contains only acoustic scores

    # We will use lm_scores from the given G, so remove lats.lm_scores here
    del lats.lm_scores
    assert not hasattr(lats, 'lm_scores')

    # inverted_lats has word IDs as labels.
    # Its aux_labels are token IDs, which is a ragged tensor k2.RaggedInt
    # if lats.aux_labels is a ragged tensor
    inverted_lats = k2.invert(lats)
    num_seqs = lats.shape[0]

    # Every sequence is intersected with the single Fsa (index 0) in G.
    b_to_a_map = torch.zeros(num_seqs, device=device, dtype=torch.int32)
    while True:
        try:
            rescoring_lats = k2.intersect_device(G_with_epsilon_loops,
                                                 inverted_lats,
                                                 b_to_a_map,
                                                 sorted_match_a=True)
            break
        except RuntimeError as e:
            logging.info(f'Caught exception:\n{e}\n')
            # Usually, this is an OOM exception. We reduce
            # the size of the lattice and redo k2.intersect_device()

            # NOTE(fangjun): The choice of the threshold 1e-5 is arbitrary here
            # to avoid OOM. We may need to fine tune it.
            num_arcs_before = inverted_lats.num_arcs
            logging.info(f'num_arcs before: {num_arcs_before}')
            inverted_lats = k2.prune_on_arc_post(inverted_lats, 1e-5, True)
            logging.info(f'num_arcs after: {inverted_lats.num_arcs}')
            if inverted_lats.num_arcs >= num_arcs_before:
                # Pruning made no progress, so retrying would fail the same
                # way forever (e.g. a non-OOM error). Re-raise the original.
                raise

    rescoring_lats = k2.top_sort(k2.connect(rescoring_lats))

    # inv_rescoring_lats has token IDs as labels
    # and word IDs as aux_labels.
    inv_rescoring_lats = k2.invert(rescoring_lats)
    return inv_rescoring_lats
def compute_am_scores_and_lm_scores(
    lats: k2.Fsa,
    word_fsas_with_epsilon_loops: k2.Fsa,
    path_to_seq_map: torch.Tensor,
    device: str = "cuda",
    batch_size: int = 500,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Compute AM and LM scores of n-best lists (represented as word_fsas).

    Args:
      lats:
        An FsaVec, which is the output of `k2.intersect_dense_pruned`.
        If it has the attribute `lm_scores`, the arc scores are split into
        AM and LM parts. Otherwise the whole score is treated as the AM
        score and no LM scores are returned.
      word_fsas_with_epsilon_loops:
        An FsaVec representing a n-best list. Note that it has been
        processed by `k2.add_epsilon_self_loops`.
      path_to_seq_map:
        A 1-D torch.Tensor with dtype torch.int32. path_to_seq_map[i]
        indicates which sequence the i-th Fsa in
        word_fsas_with_epsilon_loops belongs to.
        path_to_seq_map.numel() == word_fsas_with_epsilon_loops.arcs.dim0().
      device:
        Device to which the returned score tensors are moved. Note that it
        defaults to "cuda" regardless of the device of `lats`.
      batch_size:
        Batchify the n-best list when intersecting with inverted_lats.
        You could tune this to avoid GPU OOM issue or increase the GPU
        usage.

    Returns:
      Return a tuple (am_scores, lm_scores) of 1-D torch.Tensors with the
      AM and LM scores of each path:
      `am_scores.numel() == word_fsas_with_epsilon_loops.shape[0]`
      `lm_scores.numel() == word_fsas_with_epsilon_loops.shape[0]`
      `lm_scores` is None when `lats` has no `lm_scores` attribute.
    """
    assert len(lats.shape) == 3

    # k2.compose() currently does not support b_to_a_map. To avoid
    # replicating `lats`, we use k2.intersect_device here.
    #
    # lats has phone IDs as `labels` and word IDs as aux_labels, so we
    # need to invert it here.
    inverted_lats = k2.invert(lats)

    # Now the `labels` of inverted_lats are word IDs (a 1-D torch.Tensor)
    # and its `aux_labels` are phone IDs ( a k2.RaggedInt with 2 axes)

    # Remove its `aux_labels` since it is not needed in the
    # following computation
    del inverted_lats.aux_labels
    inverted_lats = k2.arc_sort(inverted_lats)

    am_path_lats = _intersect_device(
        inverted_lats,
        word_fsas_with_epsilon_loops,
        b_to_a_map=path_to_seq_map,
        sorted_match_a=True,
        batch_size=batch_size,
    )

    am_path_lats = k2.top_sort(k2.connect(am_path_lats))

    tot_score_device = "cpu"

    def _tot_scores(fsa_vec: k2.Fsa) -> torch.Tensor:
        # Per-path total score with log_semiring=False (best path rather
        # than log-sum), computed on CPU and moved to `device`.
        return (
            fsa_vec.to(tot_score_device)
            .get_tot_scores(use_double_scores=True, log_semiring=False)
            .to(device)
        )

    has_lm_scores = hasattr(lats, "lm_scores")
    if has_lm_scores:
        # The `scores` of every arc consists of `am_scores` and `lm_scores`
        am_path_lats.scores = am_path_lats.scores - am_path_lats.lm_scores
    am_scores = _tot_scores(am_path_lats)

    if not has_lm_scores:
        return am_scores, None

    # Start to compute lm_scores
    am_path_lats.scores = am_path_lats.lm_scores
    lm_scores = _tot_scores(am_path_lats)
    return am_scores, lm_scores