import logging
from typing import Dict, List

import k2
import torch
from k2 import Fsa


def test(self):
    for device in self.devices:
        s = '''
            0 1 1 0.2
            0 1 2 0.1
            1 2 2 0.1
            1 2 3 0.2
            2 3 -1 0
            3
        '''
        # Move the FSA to the current device so the test covers both
        # CPU and CUDA.
        fsa = k2.Fsa.from_str(s).to(device)
        fsa_vec = k2.create_fsa_vec([fsa])
        threshold_prob = 0.5
        ans = k2.prune_on_arc_post(fsa_vec,
                                   threshold_prob,
                                   use_double_scores=True)
        # With a threshold of 0.5, only the higher-posterior arc of each
        # parallel pair survives the pruning.
        expected = k2.Fsa.from_str('''
            0 1 1 0.2
            1 2 3 0.2
            2 3 -1 0
            3
        ''')
        assert str(ans[0]) == str(expected)
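# The method above needs a surrounding `unittest.TestCase` to run. Below
# is a minimal harness sketch; the class name is hypothetical and the
# device-list setup is an assumption modeled on how k2's own tests
# populate `self.devices`. Run it e.g. with `python -m unittest`.
import unittest


class TestPruneOnArcPost(unittest.TestCase):  # hypothetical name

    @classmethod
    def setUpClass(cls):
        cls.devices = [torch.device('cpu')]
        if torch.cuda.is_available():
            cls.devices.append(torch.device('cuda', 0))

    # Bind the function defined above as a test method.
    test = test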
def rescore_with_whole_lattice(lats: k2.Fsa,
                               G_with_epsilon_loops: k2.Fsa,
                               lm_scale_list: List[float]
                               ) -> Dict[str, k2.Fsa]:
    '''Use the whole lattice to rescore.

    Args:
      lats:
        An FsaVec. It can be the output of `k2.intersect_dense_pruned`.
      G_with_epsilon_loops:
        An FsaVec representing the language model (LM).
        Note that it is an FsaVec, but it contains only one Fsa.
      lm_scale_list:
        A list containing lm_scale values.
    Returns:
      A dict whose keys have the form `lm_scale_{lm_scale}` and whose
      values are FsaVecs containing the best decoding path for each
      sequence in the lattice.
    '''
    assert len(lats.shape) == 3
    assert hasattr(lats, 'lm_scores')
    assert G_with_epsilon_loops.shape == (1, None, None)

    device = lats.device
    # Now, lats.scores contains only am_scores.
    lats.scores = lats.scores - lats.lm_scores
    # We will use lm_scores from G, so remove lats.lm_scores here.
    del lats.lm_scores
    assert hasattr(lats, 'lm_scores') is False

    # inverted_lats has word IDs as labels.
    # Its aux_labels are phone IDs, which is a ragged tensor k2.RaggedInt.
    inverted_lats = k2.invert(lats)
    num_seqs = lats.shape[0]
    b_to_a_map = torch.zeros(num_seqs, device=device, dtype=torch.int32)
    try:
        rescoring_lats = k2.intersect_device(G_with_epsilon_loops,
                                             inverted_lats,
                                             b_to_a_map,
                                             sorted_match_a=True)
    except RuntimeError as e:
        print(f'Caught exception:\n{e}\n')
        print(f'Number of FSAs: {inverted_lats.shape[0]}')
        print('num_arcs before pruning: ',
              inverted_lats.arcs.num_elements())

        # NOTE(fangjun): The choice of the threshold 0.001 is arbitrary
        # here to avoid OOM. We may need to fine-tune it.
        inverted_lats = k2.prune_on_arc_post(inverted_lats, 0.001, True)
        print('num_arcs after pruning: ',
              inverted_lats.arcs.num_elements())

        rescoring_lats = k2.intersect_device(G_with_epsilon_loops,
                                             inverted_lats,
                                             b_to_a_map,
                                             sorted_match_a=True)

    rescoring_lats = k2.top_sort(
        k2.connect(rescoring_lats.to('cpu')).to(device))

    # inv_lats has phone IDs as labels and word IDs as aux_labels.
    inv_lats = k2.invert(rescoring_lats)

    ans = dict()
    # The following implements
    #   scores = (scores - lm_scores) / lm_scale + lm_scores
    #          = scores / lm_scale + lm_scores * (1 - 1 / lm_scale)
    saved_scores = inv_lats.scores.clone()
    for lm_scale in lm_scale_list:
        am_scores = saved_scores - inv_lats.lm_scores
        am_scores /= lm_scale
        inv_lats.scores = am_scores + inv_lats.lm_scores

        best_paths = k2.shortest_path(inv_lats, use_double_scores=True)
        key = f'lm_scale_{lm_scale}'
        ans[key] = best_paths
    return ans
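# A quick sanity check of the interpolation identity quoted in the loop
# above; this is plain algebra, independent of k2, and all names below
# are local to this example:
def _check_interpolation_identity():  # hypothetical self-test
    scores = torch.tensor([1.5, -0.3, 2.0])
    lm_scores = torch.tensor([0.5, 0.1, -0.2])
    lm_scale = 1.2
    lhs = (scores - lm_scores) / lm_scale + lm_scores
    rhs = scores / lm_scale + lm_scores * (1 - 1 / lm_scale)
    assert torch.allclose(lhs, rhs)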
# A variant of rescore_with_whole_lattice() that adds epsilon self-loops
# to the inverted lattice before intersection and returns the best paths
# directly instead of a dict over LM scales.
def rescore_with_whole_lattice(lats: k2.Fsa,
                               G_with_epsilon_loops: k2.Fsa) -> k2.Fsa:
    '''Use the whole lattice to rescore.

    Args:
      lats:
        An FsaVec. It can be the output of `k2.intersect_dense_pruned`.
      G_with_epsilon_loops:
        An FsaVec representing the language model (LM).
        Note that it is an FsaVec, but it contains only one Fsa.
    Returns:
      An FsaVec containing the best decoding path for each sequence
      in the lattice.
    '''
    assert len(lats.shape) == 3
    assert hasattr(lats, 'lm_scores')
    assert G_with_epsilon_loops.shape == (1, None, None)

    device = lats.device
    # Now, lats.scores contains only am_scores.
    lats.scores = lats.scores - lats.lm_scores

    # inverted_lats has word IDs as labels.
    # Its aux_labels are phone IDs, which is a ragged tensor k2.RaggedInt.
    inverted_lats = k2.invert(lats)
    num_seqs = lats.shape[0]

    inverted_lats_with_epsilon_loops = k2.add_epsilon_self_loops(
        inverted_lats)

    b_to_a_map = torch.zeros(num_seqs, device=device, dtype=torch.int32)
    try:
        rescoring_lats = k2.intersect_device(
            G_with_epsilon_loops,
            inverted_lats_with_epsilon_loops,
            b_to_a_map,
            sorted_match_a=True)
    except RuntimeError as e:
        print(f'Caught exception:\n{e}\n')
        print(f'Number of FSAs: {inverted_lats.shape[0]}')
        print('num_arcs before pruning: ',
              inverted_lats_with_epsilon_loops.arcs.num_elements())

        # NOTE(fangjun): The choice of the threshold 0.001 is arbitrary
        # here to avoid OOM. We may need to fine-tune it.
        inverted_lats = k2.prune_on_arc_post(inverted_lats, 0.001, True)
        inverted_lats_with_epsilon_loops = k2.add_epsilon_self_loops(
            inverted_lats)
        print('num_arcs after pruning: ',
              inverted_lats_with_epsilon_loops.arcs.num_elements())

        rescoring_lats = k2.intersect_device(
            G_with_epsilon_loops,
            inverted_lats_with_epsilon_loops,
            b_to_a_map,
            sorted_match_a=True)

    rescoring_lats = k2.top_sort(k2.connect(
        rescoring_lats.to('cpu'))).to(device)

    # inverted_rescoring_lats has phone IDs as labels
    # and word IDs as aux_labels.
    inverted_rescoring_lats = k2.invert(rescoring_lats)
    inverted_rescoring_lats = k2.remove_epsilon_self_loops(
        inverted_rescoring_lats)

    best_paths = k2.shortest_path(inverted_rescoring_lats,
                                  use_double_scores=True)
    return best_paths
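# For reference, the `G_with_epsilon_loops` argument of the functions
# above can be prepared along these lines (a sketch; `G` is assumed to
# be a word-level LM loaded as a single k2.Fsa, and `prepare_G` is a
# hypothetical helper):
def prepare_G(G: k2.Fsa) -> k2.Fsa:
    G = k2.add_epsilon_self_loops(G)
    # sorted_match_a=True in k2.intersect_device() requires the a_fsas,
    # i.e. G, to be arc sorted.
    G = k2.arc_sort(G)
    return k2.create_fsa_vec([G])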
def whole_lattice_rescoring(lats: Fsa, G_with_epsilon_loops: Fsa) -> Fsa:
    '''Rescore the 1st pass lattice with an LM.

    In general, the G in HLG used to obtain `lats` is a 3-gram LM.
    This function replaces the 3-gram LM in `lats` with a 4-gram LM.

    Args:
      lats:
        The decoding lattice from the 1st pass. We assume it is the
        result of intersecting HLG with the network output.
      G_with_epsilon_loops:
        An LM. It is usually a 4-gram LM with epsilon self-loops.
        It should be arc sorted.
    Returns:
      Return a new lattice rescored with the given G.
    '''
    assert len(lats.shape) == 3, f'{lats.shape}'
    assert hasattr(lats, 'lm_scores')
    assert G_with_epsilon_loops.shape == (1, None, None), \
        f'{G_with_epsilon_loops.shape}'

    device = lats.device
    # Now lats contains only acoustic scores.
    lats.scores = lats.scores - lats.lm_scores
    # We will use lm_scores from the given G, so remove lats.lm_scores here.
    del lats.lm_scores
    assert hasattr(lats, 'lm_scores') is False

    # inverted_lats has word IDs as labels.
    # Its aux_labels are token IDs, which is a ragged tensor k2.RaggedInt
    # if lats.aux_labels is a ragged tensor.
    inverted_lats = k2.invert(lats)
    num_seqs = lats.shape[0]

    b_to_a_map = torch.zeros(num_seqs, device=device, dtype=torch.int32)

    while True:
        try:
            rescoring_lats = k2.intersect_device(G_with_epsilon_loops,
                                                 inverted_lats,
                                                 b_to_a_map,
                                                 sorted_match_a=True)
            break
        except RuntimeError as e:
            logging.info(f'Caught exception:\n{e}\n')
            # Usually, this is an OOM exception. We reduce
            # the size of the lattice and redo k2.intersect_device().

            # NOTE(fangjun): The choice of the threshold 1e-5 is arbitrary
            # here to avoid OOM. We may need to fine-tune it.
            logging.info(f'num_arcs before: {inverted_lats.num_arcs}')
            inverted_lats = k2.prune_on_arc_post(inverted_lats, 1e-5, True)
            logging.info(f'num_arcs after: {inverted_lats.num_arcs}')

    rescoring_lats = k2.top_sort(k2.connect(rescoring_lats))

    # inv_rescoring_lats has token IDs as labels
    # and word IDs as aux_labels.
    inv_rescoring_lats = k2.invert(rescoring_lats)
    return inv_rescoring_lats
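# A hypothetical call site for whole_lattice_rescoring(): `lattice` is
# assumed to be the 1st-pass FsaVec (carrying an `lm_scores` attribute)
# and `G_4gram` the prepared 4-gram LM FsaVec, e.g. from prepare_G()
# above:
def decode_with_4gram(lattice: Fsa, G_4gram: Fsa) -> Fsa:  # hypothetical
    rescored = whole_lattice_rescoring(lattice, G_4gram)
    # Pick the single best path per sequence from the rescored lattice.
    return k2.shortest_path(rescored, use_double_scores=True)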