def test_no_repeated(self):
    # standard ctc topo and modified ctc topo
    # should be equivalent if there are no
    # repeated neighboring symbols in the transcript
    max_token = 3
    standard = k2.ctc_topo(max_token, modified=False)
    modified = k2.ctc_topo(max_token, modified=True)
    transcript = k2.linear_fsa([1, 2, 3])
    standard_graph = k2.compose(standard, transcript)
    modified_graph = k2.compose(modified, transcript)

    input1 = k2.linear_fsa([1, 1, 1, 0, 0, 2, 2, 3, 3])
    input2 = k2.linear_fsa([1, 1, 0, 0, 2, 2, 0, 3, 3])
    inputs = [input1, input2]
    for i in inputs:
        lattice1 = k2.intersect(standard_graph,
                                i,
                                treat_epsilons_specially=False)
        lattice2 = k2.intersect(modified_graph,
                                i,
                                treat_epsilons_specially=False)
        lattice1 = k2.connect(lattice1)
        lattice2 = k2.connect(lattice2)

        aux_labels1 = lattice1.aux_labels[lattice1.aux_labels != 0]
        aux_labels2 = lattice2.aux_labels[lattice2.aux_labels != 0]
        aux_labels1 = aux_labels1[:-1]  # remove -1
        aux_labels2 = aux_labels2[:-1]
        assert torch.all(torch.eq(aux_labels1, aux_labels2))
        assert torch.all(torch.eq(aux_labels2, torch.tensor([1, 2, 3])))
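# A hedged sketch (separate from the test above): build both topologies
# compared in the test and print their sizes. Only `k2.ctc_topo` and
# `Fsa.num_arcs`, both already used in this section, are assumed.
import k2

standard = k2.ctc_topo(2, modified=False)
modified = k2.ctc_topo(2, modified=True)
print('standard ctc topo:', standard.num_arcs, 'arcs')
print('modified ctc topo:', modified.num_arcs, 'arcs')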
def compile(self, texts: Iterable[str],
            P: k2.Fsa) -> Tuple[k2.Fsa, k2.Fsa]:
    '''Create numerator and denominator graphs from transcripts
    and the bigram phone LM.

    Args:
      texts:
        A list of transcripts. Within a transcript, words are
        separated by spaces.
      P:
        The bigram phone LM created by :func:`create_bigram_phone_lm`.

    Returns:
      A tuple (num_graph, den_graph), where

        - `num_graph` is the numerator graph. It is an FsaVec with
          shape `(len(texts), None, None)`.

        - `den_graph` is the denominator graph. It is an FsaVec with
          the same shape as `num_graph`.
    '''
    assert P.is_cpu()

    ctc_topo_P = k2.intersect(self.ctc_topo, P).invert_()
    ctc_topo_P = k2.connect(ctc_topo_P)

    num_graphs = k2.create_fsa_vec(
        [self.compile_one_and_cache(text) for text in texts])

    num = k2.compose(ctc_topo_P, num_graphs)
    num = k2.connect(num)
    num = k2.arc_sort(num)

    den = k2.create_fsa_vec([ctc_topo_P.detach()] * len(texts))

    return num, den
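# A hedged toy sketch of the `ctc_topo_P` construction above. The
# hand-written two-phone acceptor `P` below is hypothetical; a real P
# would come from create_bigram_phone_lm. Only k2 APIs already used in
# this section (Fsa.from_str, arc_sort, ctc_topo, intersect, connect,
# invert_) are assumed.
import k2

P = k2.arc_sort(k2.Fsa.from_str('''
0 0 1 0.0
0 0 2 0.0
0 1 -1 0.0
1
'''))
ctc_topo = k2.arc_sort(k2.ctc_topo(2))
ctc_topo_P = k2.connect(k2.intersect(ctc_topo, P).invert_())
print(ctc_topo_P.num_arcs)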
def compile_one_and_cache(self, text: str) -> k2.Fsa:
    tokens = (token if token in self.words else self.oov
              for token in text.split(' '))
    word_ids = [self.words[token] for token in tokens]
    fsa = k2.linear_fsa(word_ids)
    decoding_graph = k2.connect(k2.intersect(fsa, self.L_inv)).invert_()
    decoding_graph = k2.arc_sort(decoding_graph)
    decoding_graph = k2.compose(self.ctc_topo, decoding_graph)
    decoding_graph = k2.connect(decoding_graph)
    return decoding_graph
def test_with_repeated(self):
    max_token = 2
    standard = k2.ctc_topo(max_token, modified=False)
    modified = k2.ctc_topo(max_token, modified=True)
    transcript = k2.linear_fsa([1, 2, 2])
    standard_graph = k2.compose(standard, transcript)
    modified_graph = k2.compose(modified, transcript)

    # There is a blank separating 2s in the input,
    # so standard and modified ctc topo should be equivalent
    input = k2.linear_fsa([1, 1, 2, 2, 0, 2, 2, 0, 0])
    lattice1 = k2.intersect(standard_graph,
                            input,
                            treat_epsilons_specially=False)
    lattice2 = k2.intersect(modified_graph,
                            input,
                            treat_epsilons_specially=False)
    lattice1 = k2.connect(lattice1)
    lattice2 = k2.connect(lattice2)

    aux_labels1 = lattice1.aux_labels[lattice1.aux_labels != 0]
    aux_labels2 = lattice2.aux_labels[lattice2.aux_labels != 0]
    aux_labels1 = aux_labels1[:-1]  # remove -1
    aux_labels2 = aux_labels2[:-1]
    assert torch.all(torch.eq(aux_labels1, aux_labels2))
    assert torch.all(torch.eq(aux_labels1, torch.tensor([1, 2, 2])))

    # There are no blanks separating 2s in the input.
    # The standard ctc topo requires that there must be a blank
    # separating 2s, so lattice1 in the following is empty
    input = k2.linear_fsa([1, 1, 2, 2, 0, 0])
    lattice1 = k2.intersect(standard_graph,
                            input,
                            treat_epsilons_specially=False)
    lattice2 = k2.intersect(modified_graph,
                            input,
                            treat_epsilons_specially=False)
    lattice1 = k2.connect(lattice1)
    lattice2 = k2.connect(lattice2)

    assert lattice1.num_arcs == 0

    # Since there are two 2s in the input and there are also two 2s
    # in the transcript, the final output contains only one path.
    # If there were more than two 2s in the input, the output
    # would contain more than one path
    aux_labels2 = lattice2.aux_labels[lattice2.aux_labels != 0]
    aux_labels2 = aux_labels2[:-1]  # remove -1
    # Note: lattice1 is empty here, so the check is on aux_labels2
    # (the original compared aux_labels1, which was left over from
    # the first input).
    assert torch.all(torch.eq(aux_labels2, torch.tensor([1, 2, 2])))
def test_compose(self):
    s = '''
        0 1 11 1 1.0
        0 2 12 2 2.5
        1 3 -1 -1 0
        2 3 -1 -1 2.5
        3
    '''
    a_fsa = k2.Fsa.from_str(s, num_aux_labels=1).requires_grad_(True)

    s = '''
        0 1 1 1 1.0
        0 2 2 3 3.0
        1 2 3 2 2.5
        2 3 -1 -1 2.0
        3
    '''
    b_fsa = k2.Fsa.from_str(s, num_aux_labels=1).requires_grad_(True)

    ans = k2.compose(a_fsa, b_fsa, inner_labels='inner')
    ans = k2.connect(ans)
    ans = k2.create_fsa_vec([ans])

    scores = ans.get_tot_scores(log_semiring=True,
                                use_double_scores=False)
    # The reference values for `scores`, `a_fsa.grad` and `b_fsa.grad`
    # are computed using GTN.
    # See https://bit.ly/3heLAJq
    assert scores.item() == 10
    scores.backward()
    assert torch.allclose(a_fsa.grad, torch.tensor([0., 1., 0., 1.]))
    assert torch.allclose(b_fsa.grad, torch.tensor([0., 1., 0., 1.]))
def get_hierarchical_targets(ys: List[List[int]],
                             lexicon: k2.Fsa) -> List[Tensor]:
    """Get hierarchical transcripts (i.e., phone level transcripts)
    from transcripts (i.e., word level transcripts).

    Args:
        ys: Word level transcripts.
        lexicon: Its labels are words, while its aux_labels are phones.

    Returns:
        List[Tensor]: Phone level transcripts.
    """
    if lexicon is None:
        return ys
    else:
        L_inv = lexicon

    n_batch = len(ys)
    indices = torch.tensor(range(n_batch))

    transcripts = k2.create_fsa_vec([k2.linear_fsa(x) for x in ys])
    transcripts_lexicon = k2.intersect(transcripts, L_inv)
    transcripts_lexicon = k2.arc_sort(k2.connect(transcripts_lexicon))
    transcripts_lexicon = k2.remove_epsilon(transcripts_lexicon)
    transcripts_lexicon = k2.shortest_path(transcripts_lexicon,
                                           use_double_scores=True)

    ys = get_texts(transcripts_lexicon, indices)
    ys = [torch.tensor(y) for y in ys]

    return ys
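# A hedged toy example of the word-to-phone mapping performed above.
# The inverted lexicon is hand-written and hypothetical: word 1 maps to
# phone 5 and word 2 maps to phone 7; its labels are words and its
# aux_labels are phones, matching what `get_hierarchical_targets`
# expects of `lexicon`.
import k2

L_inv = k2.arc_sort(
    k2.Fsa.from_str('''
0 0 1 5 0.0
0 0 2 7 0.0
0 1 -1 -1 0.0
1
''', num_aux_labels=1))
words = k2.linear_fsa([1, 2])
phones = k2.connect(k2.intersect(words, L_inv))
print(phones.aux_labels)  # tensor([5, 7, -1], dtype=torch.int32)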
def test_case1(self):
    devices = [torch.device('cpu')]
    if torch.cuda.is_available():
        devices.append(torch.device('cuda'))

    for device in devices:
        # suppose we have five symbols: <blk>, a, b, c, d
        torch_activation = torch.tensor([0.2, 0.2, 0.2, 0.2,
                                         0.2]).to(device)
        k2_activation = torch_activation.detach().clone()

        # (T, N, C)
        torch_activation = torch_activation.reshape(
            1, 1, -1).requires_grad_(True)

        # (N, T, C)
        k2_activation = k2_activation.reshape(1, 1,
                                              -1).requires_grad_(True)

        torch_log_probs = torch.nn.functional.log_softmax(
            torch_activation, dim=-1)  # (T, N, C)

        # we have only one sequence and its label is `a`
        targets = torch.tensor([1]).to(device)
        input_lengths = torch.tensor([1]).to(device)
        target_lengths = torch.tensor([1]).to(device)
        torch_loss = torch.nn.functional.ctc_loss(
            log_probs=torch_log_probs,
            targets=targets,
            input_lengths=input_lengths,
            target_lengths=target_lengths,
            reduction='none')

        assert torch.allclose(torch_loss,
                              torch.tensor([1.6094379425049]).to(device))

        # (N, T, C)
        k2_log_probs = torch.nn.functional.log_softmax(k2_activation,
                                                       dim=-1)

        supervision_segments = torch.tensor([[0, 0, 1]],
                                            dtype=torch.int32)
        dense_fsa_vec = k2.DenseFsaVec(k2_log_probs,
                                       supervision_segments).to(device)

        ctc_topo_inv = k2.arc_sort(
            build_ctc_topo([0, 1, 2, 3, 4]).invert_())
        linear_fsa = k2.linear_fsa([1])
        decoding_graph = k2.intersect(ctc_topo_inv, linear_fsa)
        decoding_graph = k2.connect(decoding_graph).invert_().to(device)

        target_graph = k2.intersect_dense(decoding_graph, dense_fsa_vec,
                                          100.0)

        k2_scores = target_graph.get_tot_scores(log_semiring=True,
                                                use_double_scores=False)
        assert torch.allclose(torch_loss, -1 * k2_scores)

        torch_loss.backward()
        (-k2_scores).backward()
        assert torch.allclose(torch_activation.grad, k2_activation.grad)
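# A quick sanity check (added here, not part of the original test) for
# the reference value used above: with uniform probabilities over 5
# symbols and T = 1, the only alignment for label `a` has probability
# 1/5, so the CTC loss is -log(1/5) = log(5).
import math

assert abs(math.log(5) - 1.6094379425049) < 1e-6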
def test_ragged_aux_labels(self):
    s1 = '''
        0 1 1 0.1
        0 2 5 0.6
        1 2 3 0.3
        2 3 3 0.5
        2 4 2 0.6
        3 5 -1 0.7
        4 5 -1 0.8
        5
    '''
    s2 = '''
        0 0 2 1 1
        0 1 4 3 2
        0 1 6 2 2
        0 2 -1 -1 0
        1 1 2 5 3
        1 2 -1 -1 4
        2
    '''
    # https://git.io/JqNok
    fsa1 = k2.Fsa.from_str(s1)
    fsa1.aux_labels = k2.RaggedInt('[[2] [2 4] [5] [3] [2] [-1] [-1]]')

    # https://git.io/JqNaJ
    fsa2 = k2.Fsa.from_str(s2, num_aux_labels=1)

    # https://git.io/JqNon
    ans = k2.connect(k2.compose(fsa1, fsa2, inner_labels='phones'))
    assert torch.all(torch.eq(ans.labels, torch.tensor([5, 0, 2, -1])))
    assert torch.all(torch.eq(ans.phones, torch.tensor([2, 4, 2, -1])))
    assert str(ans.aux_labels) == str(k2.RaggedInt('[[1] [3] [5] [-1]]'))
def test_compose_inner_labels(self):
    s1 = '''
        0 1 1 2 0.1
        0 2 0 2 0.2
        1 3 3 5 0.3
        2 3 5 4 0.4
        3 4 3 3 0.5
        3 5 2 2 0.6
        4 6 -1 -1 0.7
        5 6 -1 -1 0.8
        6
    '''
    s2 = '''
        0 0 2 1 1
        0 1 4 3 2
        0 1 6 2 2
        0 2 -1 -1 0
        1 1 2 5 3
        1 2 -1 -1 4
        2
    '''
    # https://git.io/JqN2j
    fsa1 = k2.Fsa.from_str(s1, num_aux_labels=1)

    # https://git.io/JqNaJ
    fsa2 = k2.Fsa.from_str(s2, num_aux_labels=1)

    # https://git.io/JqNaT
    ans = k2.connect(k2.compose(fsa1, fsa2, inner_labels='phones'))
    assert torch.all(torch.eq(ans.labels, torch.tensor([0, 5, 2, -1])))
    assert torch.all(torch.eq(ans.phones, torch.tensor([2, 4, 2, -1])))
    assert torch.all(torch.eq(ans.aux_labels,
                              torch.tensor([1, 3, 5, -1])))
def compile_one_and_cache(self, text: str) -> Fsa:
    tokens = (token if token in self.vocab._sym2id else self.oov
              for token in text.split(' '))
    word_ids = [self.vocab.get(token) for token in tokens]
    fsa = k2.linear_fsa(word_ids)
    decoding_graph = k2.connect(k2.intersect(fsa, self.L_inv)).invert_()
    decoding_graph = k2.add_epsilon_self_loops(decoding_graph)
    return decoding_graph
def compile_LG(L: Fsa, G: Fsa, labels_disambig_id_start: int,
               aux_labels_disambig_id_start: int) -> Fsa:
    """Creates a decoding graph using a lexicon fst ``L`` and language
    model fsa ``G``. Involves arc sorting, intersection,
    determinization, removal of disambiguation symbols and adding
    epsilon self-loops.

    Args:
        L:
            An ``Fsa`` that represents the lexicon (L), i.e. has phones
            as ``symbols`` and words as ``aux_symbols``.
        G:
            An ``Fsa`` that represents the language model (G), i.e.
            it's an acceptor with words as ``symbols``.
        labels_disambig_id_start:
            An integer ID corresponding to the first disambiguation
            symbol in the phonetic alphabet.
        aux_labels_disambig_id_start:
            An integer ID corresponding to the first disambiguation
            symbol in the words vocabulary.

    Returns:
        An ``Fsa`` representing the decoding graph LG.
    """
    L_inv = k2.arc_sort(L.invert_())
    G = k2.arc_sort(G)
    logging.debug("Intersecting L and G")
    LG = k2.intersect(L_inv, G)
    logging.debug(f'LG shape = {LG.shape}')
    logging.debug("Connecting L*G")
    LG = k2.connect(LG).invert_()
    logging.debug(f'LG shape = {LG.shape}')
    logging.debug("Determinizing L*G")
    LG = k2.determinize(LG)
    logging.debug(f'LG shape = {LG.shape}')
    logging.debug("Connecting det(L*G)")
    LG = k2.connect(LG)
    logging.debug(f'LG shape = {LG.shape}')
    logging.debug("Removing disambiguation symbols on L*G")
    LG.labels[LG.labels >= labels_disambig_id_start] = 0
    LG.aux_labels[LG.aux_labels >= aux_labels_disambig_id_start] = 0
    LG = k2.add_epsilon_self_loops(LG)
    LG = k2.arc_sort(LG)
    logging.debug(
        f'LG is arc sorted: {(LG.properties & k2.fsa_properties.ARC_SORTED) != 0}'
    )
    return LG
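# A hedged toy illustration of the disambiguation-symbol removal step
# above: labels at or above the first disambiguation ID are overwritten
# with 0 (epsilon) in place. The FSA and the ID 100 below are made up.
import k2

fsa = k2.Fsa.from_str('''
0 1 3 0.0
1 2 100 0.0
2 3 -1 0.0
3
''')
labels_disambig_id_start = 100
fsa.labels[fsa.labels >= labels_disambig_id_start] = 0
print(fsa.labels)  # tensor([3, 0, -1], dtype=torch.int32)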
def test_case3(self):
    devices = [torch.device('cpu')]
    if torch.cuda.is_available():
        devices.append(torch.device('cuda'))

    for device in devices:
        # (T, N, C)
        torch_activation = torch.tensor([[
            [-5, -4, -3, -2, -1],
            [-10, -9, -8, -7, -6],
            [-15, -14, -13, -12, -11.],
        ]]).permute(1, 0, 2).to(device).requires_grad_(True)
        torch_activation = torch_activation.to(torch.float32)
        torch_activation.requires_grad_(True)

        k2_activation = torch_activation.detach().clone().requires_grad_(
            True)

        torch_log_probs = torch.nn.functional.log_softmax(
            torch_activation, dim=-1)  # (T, N, C)
        # we have only one sequence and its labels are `b,c`
        targets = torch.tensor([2, 3]).to(device)
        input_lengths = torch.tensor([3]).to(device)
        target_lengths = torch.tensor([2]).to(device)

        torch_loss = torch.nn.functional.ctc_loss(
            log_probs=torch_log_probs,
            targets=targets,
            input_lengths=input_lengths,
            target_lengths=target_lengths,
            reduction='none')

        act = k2_activation.permute(1, 0, 2)  # (T, N, C) -> (N, T, C)
        k2_log_probs = torch.nn.functional.log_softmax(act, dim=-1)

        supervision_segments = torch.tensor([[0, 0, 3]],
                                            dtype=torch.int32)
        dense_fsa_vec = k2.DenseFsaVec(k2_log_probs,
                                       supervision_segments).to(device)

        ctc_topo_inv = k2.arc_sort(
            build_ctc_topo([0, 1, 2, 3, 4]).invert_())
        linear_fsa = k2.linear_fsa([2, 3])
        decoding_graph = k2.intersect(ctc_topo_inv, linear_fsa)
        decoding_graph = k2.connect(decoding_graph).invert_().to(device)

        target_graph = k2.intersect_dense(decoding_graph, dense_fsa_vec,
                                          100.0)

        k2_scores = target_graph.get_tot_scores(log_semiring=True,
                                                use_double_scores=False)
        assert torch.allclose(torch_loss, -1 * k2_scores)
        assert torch.allclose(torch_loss,
                              torch.tensor([4.938850402832]).to(device))

        torch_loss.backward()
        (-k2_scores).backward()
        assert torch.allclose(torch_activation.grad, k2_activation.grad)
def test_random_case1(self):
    # 1 sequence
    devices = [torch.device('cpu')]
    if torch.cuda.is_available():
        devices.append(torch.device('cuda', 0))

    for device in devices:
        T = torch.randint(10, 100, (1,)).item()
        C = torch.randint(20, 30, (1,)).item()
        torch_activation = torch.rand((1, T + 10, C),
                                      dtype=torch.float32,
                                      device=device).requires_grad_(True)
        k2_activation = torch_activation.detach().clone().requires_grad_(
            True)

        # [N, T, C] -> [T, N, C]
        torch_log_probs = torch.nn.functional.log_softmax(
            torch_activation.permute(1, 0, 2), dim=-1)

        input_lengths = torch.tensor([T]).to(device)
        target_lengths = torch.randint(1, T, (1,)).to(device)
        targets = torch.randint(1, C - 1,
                                (target_lengths.item(),)).to(device)

        torch_loss = torch.nn.functional.ctc_loss(
            log_probs=torch_log_probs,
            targets=targets,
            input_lengths=input_lengths,
            target_lengths=target_lengths,
            reduction='none')
        k2_log_probs = torch.nn.functional.log_softmax(k2_activation,
                                                       dim=-1)
        supervision_segments = torch.tensor([[0, 0, T]],
                                            dtype=torch.int32)
        dense_fsa_vec = k2.DenseFsaVec(k2_log_probs,
                                       supervision_segments).to(device)

        ctc_topo_inv = k2.arc_sort(
            build_ctc_topo(list(range(C))).invert_())
        linear_fsa = k2.linear_fsa([targets.tolist()])
        decoding_graph = k2.intersect(ctc_topo_inv, linear_fsa)
        decoding_graph = k2.connect(decoding_graph).invert_().to(device)

        target_graph = k2.intersect_dense(decoding_graph, dense_fsa_vec,
                                          100.0)

        k2_scores = target_graph.get_tot_scores(log_semiring=True,
                                                use_double_scores=False)
        assert torch.allclose(torch_loss, -1 * k2_scores)

        scale = torch.rand_like(torch_loss) * 100
        (torch_loss * scale).sum().backward()
        (-k2_scores * scale).sum().backward()
        assert torch.allclose(torch_activation.grad,
                              k2_activation.grad,
                              atol=1e-2)
def compute_am_scores(lats: k2.Fsa, word_fsas_with_epsilon_loops: k2.Fsa,
                      path_to_seq_map: torch.Tensor) -> torch.Tensor:
    '''Compute AM scores of n-best lists (represented as word_fsas).

    Args:
      lats:
        An FsaVec, which is the output of `k2.intersect_dense_pruned`.
        It must have the attribute `lm_scores`.
      word_fsas_with_epsilon_loops:
        An FsaVec representing an n-best list. Note that it has been
        processed by `k2.add_epsilon_self_loops`.
      path_to_seq_map:
        A 1-D torch.Tensor with dtype torch.int32. path_to_seq_map[i]
        indicates which sequence the i-th Fsa in
        word_fsas_with_epsilon_loops belongs to.
        path_to_seq_map.numel() == word_fsas_with_epsilon_loops.arcs.dim0().

    Returns:
      Return a 1-D torch.Tensor containing the AM scores of each path.
      `ans.numel() == word_fsas_with_epsilon_loops.shape[0]`
    '''
    device = lats.device
    assert len(lats.shape) == 3
    assert hasattr(lats, 'lm_scores')

    # k2.compose() currently does not support b_to_a_map. To avoid
    # replicating `lats`, we use k2.intersect_device here.
    #
    # lats has phone IDs as `labels` and word IDs as aux_labels, so we
    # need to invert it here.
    inverted_lats = k2.invert(lats)

    # Now the `labels` of inverted_lats are word IDs (a 1-D torch.Tensor)
    # and its `aux_labels` are phone IDs (a k2.RaggedInt with 2 axes)

    # Remove its `aux_labels` since it is not needed in the
    # following computation
    del inverted_lats.aux_labels
    inverted_lats = k2.arc_sort(inverted_lats)

    am_path_lats = _intersect_device(inverted_lats,
                                     word_fsas_with_epsilon_loops,
                                     b_to_a_map=path_to_seq_map,
                                     sorted_match_a=True)

    # NOTE: `k2.connect` and `k2.top_sort` support only CPU at present
    am_path_lats = k2.top_sort(k2.connect(am_path_lats.to('cpu'))).to(device)

    # The `scores` of every arc consists of `am_scores` and `lm_scores`
    am_path_lats.scores = am_path_lats.scores - am_path_lats.lm_scores

    am_scores = am_path_lats.get_tot_scores(True, True)

    return am_scores
def _generate_fsa_vec(min_num_fsas: int = 20,
                      max_num_fsas: int = 21,
                      acyclic: bool = True,
                      max_symbol: int = 20,
                      min_num_arcs: int = 10,
                      max_num_arcs: int = 15) -> k2.Fsa:
    # Note: the original called k2.random_fsa_vec positionally and
    # dropped `max_symbol`, which would have bound `min_num_arcs` to
    # the `max_symbol` parameter; keyword arguments avoid that.
    fsa = k2.random_fsa_vec(min_num_fsas=min_num_fsas,
                            max_num_fsas=max_num_fsas,
                            acyclic=acyclic,
                            max_symbol=max_symbol,
                            min_num_arcs=min_num_arcs,
                            max_num_arcs=max_num_arcs)
    fsa = k2.connect(fsa)
    while True:
        # Ensure that every Fsa in the vector is non-empty after
        # k2.connect(); otherwise generate a new random FsaVec.
        success = True
        for i in range(fsa.shape[0]):
            if fsa[i].shape[0] == 0:
                success = False
                break
        if success:
            break
        else:
            fsa = k2.random_fsa_vec(min_num_fsas=min_num_fsas,
                                    max_num_fsas=max_num_fsas,
                                    acyclic=acyclic,
                                    max_symbol=max_symbol,
                                    min_num_arcs=min_num_arcs,
                                    max_num_arcs=max_num_arcs)
            fsa = k2.connect(fsa)
    return fsa
def test_case2(self):
    for device in self.devices:
        # (T, N, C)
        torch_activation = torch.arange(1, 16).reshape(1, 3, 5).permute(
            1, 0, 2).to(device)
        torch_activation = torch_activation.to(torch.float32)
        torch_activation.requires_grad_(True)

        k2_activation = torch_activation.detach().clone().requires_grad_(
            True)

        torch_log_probs = torch.nn.functional.log_softmax(
            torch_activation, dim=-1)  # (T, N, C)
        # we have only one sequence and its labels are `c,c`
        targets = torch.tensor([3, 3]).to(device)
        input_lengths = torch.tensor([3]).to(device)
        target_lengths = torch.tensor([2]).to(device)

        torch_loss = torch.nn.functional.ctc_loss(
            log_probs=torch_log_probs,
            targets=targets,
            input_lengths=input_lengths,
            target_lengths=target_lengths,
            reduction='none')

        act = k2_activation.permute(1, 0, 2)  # (T, N, C) -> (N, T, C)
        k2_log_probs = torch.nn.functional.log_softmax(act, dim=-1)

        supervision_segments = torch.tensor([[0, 0, 3]],
                                            dtype=torch.int32)
        dense_fsa_vec = k2.DenseFsaVec(k2_log_probs,
                                       supervision_segments).to(device)

        ctc_topo_inv = k2.arc_sort(
            build_ctc_topo([0, 1, 2, 3, 4]).invert_())
        linear_fsa = k2.linear_fsa([3, 3])
        decoding_graph = k2.intersect(ctc_topo_inv, linear_fsa)
        decoding_graph = k2.connect(decoding_graph).invert_().to(device)

        target_graph = k2.intersect_dense(decoding_graph, dense_fsa_vec,
                                          100.0)

        k2_scores = target_graph.get_tot_scores(log_semiring=True,
                                                use_double_scores=False)
        assert torch.allclose(torch_loss, -1 * k2_scores)
        assert torch.allclose(torch_loss,
                              torch.tensor([7.355742931366]).to(device))

        torch_loss.backward()
        (-k2_scores).backward()
        assert torch.allclose(torch_activation.grad, k2_activation.grad)
def compile_one_and_cache(self, text: str) -> k2.Fsa:
    '''Convert transcript to an Fsa with the help of lexicon
    and word symbol table.

    Args:
      text:
        The transcript containing words separated by spaces.

    Returns:
      Return an FST corresponding to the transcript. Its `labels` are
      phone IDs and `aux_labels` are word IDs.
    '''
    tokens = (token if token in self.words else self.oov
              for token in text.split(' '))
    word_ids = [self.words[token] for token in tokens]
    fsa = k2.linear_fsa(word_ids)
    num_graph = k2.connect(k2.intersect(fsa, self.L_inv)).invert_()
    num_graph = k2.arc_sort(num_graph)
    return num_graph
def test(self):
    s = '''
        0 1 1 0.1
        0 2 2 0.2
        1 4 -1 0.3
        3 4 -1 0.4
        4
    '''
    fsa = k2.Fsa.from_str(s)
    fsa.requires_grad_(True)
    expected_str = '\n'.join(['0 1 1 0.1', '1 2 -1 0.3', '2'])
    connected_fsa = k2.connect(fsa)
    actual_str = k2.to_str_simple(connected_fsa)
    assert actual_str.strip() == expected_str

    loss = connected_fsa.scores.sum()
    loss.backward()
    assert torch.allclose(fsa.scores.grad,
                          torch.tensor([1, 0, 1, 0], dtype=torch.float32))
def test_composition_equivalence(self):
    index = _generate_fsa_vec()
    index = k2.arc_sort(k2.connect(k2.remove_epsilon(index)))

    src = _generate_fsa_vec()

    replace = k2.replace_fsa(src, index, 1)
    replace = k2.top_sort(replace)

    f_fsa = _construct_f(src)
    f_fsa = k2.arc_sort(f_fsa)
    intersect = k2.intersect(index, f_fsa, treat_epsilons_specially=True)
    intersect = k2.invert(intersect)
    intersect = k2.top_sort(intersect)
    delattr(intersect, 'aux_labels')

    assert k2.is_rand_equivalent(replace,
                                 intersect,
                                 log_semiring=True,
                                 delta=1e-3)
def test_random(self):
    while True:
        fsa = k2.random_fsa(max_symbol=20,
                            min_num_arcs=50,
                            max_num_arcs=500)
        fsa = k2.arc_sort(k2.connect(k2.remove_epsilon(fsa)))
        prob = fsa.properties
        # we need a non-deterministic fsa
        if not prob & k2.fsa_properties.ARC_SORTED_AND_DETERMINISTIC:
            break

    log_semiring = False

    # test weight pushing in the tropical semiring
    dest_max = k2.determinize(
        fsa, k2.DeterminizeWeightPushingType.kTropicalWeightPushing)
    self.assertTrue(
        k2.is_rand_equivalent(fsa, dest_max, log_semiring, delta=1e-3))

    # test weight pushing in the log semiring
    dest_log = k2.determinize(
        fsa, k2.DeterminizeWeightPushingType.kLogWeightPushing)
    self.assertTrue(
        k2.is_rand_equivalent(fsa, dest_log, log_semiring, delta=1e-3))
def test_compose(self):
    s = '''
        0 1 11 1 1.0
        0 2 12 2 2.5
        1 3 -1 -1 0
        2 3 -1 -1 2.5
        3
    '''
    a_fsa = k2.Fsa.from_str(s).requires_grad_(True)

    s = '''
        0 1 1 1 1.0
        0 2 2 3 3.0
        1 2 3 2 2.5
        2 3 -1 -1 2.0
        3
    '''
    b_fsa = k2.Fsa.from_str(s).requires_grad_(True)

    ans = k2.compose(a_fsa, b_fsa, inner_labels='inner')
    ans = k2.connect(ans)

    # Convert a single FSA to an FsaVec.
    # It will retain `requires_grad_` of `ans`.
    ans.__dict__['arcs'] = _k2.create_fsa_vec([ans.arcs])
    scores = k2.get_tot_scores(ans,
                               log_semiring=True,
                               use_double_scores=False)
    # The reference values for `scores`, `a_fsa.grad` and `b_fsa.grad`
    # are computed using GTN.
    # See https://bit.ly/3heLAJq
    assert scores.item() == 10
    scores.backward()
    assert torch.allclose(a_fsa.grad, torch.tensor([0., 1., 0., 1.]))
    assert torch.allclose(b_fsa.grad, torch.tensor([0., 1., 0., 1.]))
    print(ans)
def compile_HLG(L: Fsa, G: Fsa, H: Fsa, labels_disambig_id_start: int,
                aux_labels_disambig_id_start: int) -> Fsa:
    """Creates a decoding graph using a lexicon fst ``L``, a language
    model fsa ``G`` and a topology fst ``H``. Involves arc sorting,
    intersection, determinization, removal of disambiguation symbols
    and adding epsilon self-loops.

    Args:
        L:
            An ``Fsa`` that represents the lexicon (L), i.e. has phones
            as ``symbols`` and words as ``aux_symbols``.
        G:
            An ``Fsa`` that represents the language model (G), i.e.
            it's an acceptor with words as ``symbols``.
        H:
            An ``Fsa`` that represents a specific topology used to
            convert the network outputs to a sequence of phones.
            Typically, it's a CTC topology fst, in which when 0 appears
            on the left side, it represents the blank symbol; when it
            appears on the right side, it indicates an epsilon.
        labels_disambig_id_start:
            An integer ID corresponding to the first disambiguation
            symbol in the phonetic alphabet.
        aux_labels_disambig_id_start:
            An integer ID corresponding to the first disambiguation
            symbol in the words vocabulary.

    Returns:
        An ``Fsa`` representing the decoding graph HLG.
    """
    L = k2.arc_sort(L)
    G = k2.arc_sort(G)
    logging.info("Intersecting L and G")
    LG = k2.compose(L, G)
    logging.info(f'LG shape = {LG.shape}')
    logging.info("Connecting L*G")
    LG = k2.connect(LG)
    logging.info(f'LG shape = {LG.shape}')
    logging.info("Determinizing L*G")
    LG = k2.determinize(LG)
    logging.info(f'LG shape = {LG.shape}')
    logging.info("Connecting det(L*G)")
    LG = k2.connect(LG)
    logging.info(f'LG shape = {LG.shape}')
    logging.info("Removing disambiguation symbols on L*G")
    LG.labels[LG.labels >= labels_disambig_id_start] = 0
    if isinstance(LG.aux_labels, torch.Tensor):
        LG.aux_labels[LG.aux_labels >= aux_labels_disambig_id_start] = 0
    else:
        LG.aux_labels.values()[
            LG.aux_labels.values() >= aux_labels_disambig_id_start] = 0
    logging.info("Removing epsilons")
    LG = k2.remove_epsilon(LG)
    logging.info(f'LG shape = {LG.shape}')
    logging.info("Connecting rm-eps(det(L*G))")
    LG = k2.connect(LG)
    logging.info(f'LG shape = {LG.shape}')
    LG.aux_labels = k2.ragged.remove_values_eq(LG.aux_labels, 0)
    logging.info("Arc sorting LG")
    LG = k2.arc_sort(LG)

    logging.info("Composing H and LG")
    HLG = k2.compose(H, LG, inner_labels='phones')
    logging.info("Connecting HLG")
    HLG = k2.connect(HLG)
    logging.info("Arc sorting HLG")
    HLG = k2.arc_sort(HLG)
    logging.info(
        f'HLG is arc sorted: {(HLG.properties & k2.fsa_properties.ARC_SORTED) != 0}'
    )

    # Attach a new attribute `lm_scores` so that we can recover
    # the `am_scores` later.
    # The scores on an arc consist of two parts:
    #     scores = am_scores + lm_scores
    # NOTE: we assume that both kinds of scores are in log-space.
    HLG.lm_scores = HLG.scores.clone()

    return HLG
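# A hedged sketch of how the `H` argument above can be built: the CTC
# topology over phone IDs {0..max_phone_id}, with blank (0) on the
# input side mapping to epsilon (0) on the output side. `k2.ctc_topo`
# is assumed, as used elsewhere in this section; the original code may
# instead use a helper such as build_ctc_topo.
import k2

max_phone_id = 3
H = k2.arc_sort(k2.ctc_topo(max_phone_id))
print(H)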
def compile_LG(L: Fsa, G: Fsa, ctc_topo: Fsa,
               labels_disambig_id_start: int,
               aux_labels_disambig_id_start: int) -> Fsa:
    """Creates a decoding graph using a lexicon fst ``L``, a language
    model fsa ``G`` and a CTC topology fst ``ctc_topo``. Involves arc
    sorting, intersection, determinization, removal of disambiguation
    symbols and adding epsilon self-loops.

    Args:
        L:
            An ``Fsa`` that represents the lexicon (L), i.e. has phones
            as ``symbols`` and words as ``aux_symbols``.
        G:
            An ``Fsa`` that represents the language model (G), i.e.
            it's an acceptor with words as ``symbols``.
        ctc_topo:
            CTC topology fst, in which when 0 appears on the left side,
            it represents the blank symbol; when it appears on the
            right side, it indicates an epsilon.
        labels_disambig_id_start:
            An integer ID corresponding to the first disambiguation
            symbol in the phonetic alphabet.
        aux_labels_disambig_id_start:
            An integer ID corresponding to the first disambiguation
            symbol in the words vocabulary.

    Returns:
        An ``Fsa`` representing the decoding graph.
    """
    L = k2.arc_sort(L)
    G = k2.arc_sort(G)
    logging.info("Intersecting L and G")
    LG = k2.compose(L, G)
    logging.info(f'LG shape = {LG.shape}')
    logging.info("Connecting L*G")
    LG = k2.connect(LG)
    logging.info(f'LG shape = {LG.shape}')
    logging.info("Determinizing L*G")
    LG = k2.determinize(LG)
    logging.info(f'LG shape = {LG.shape}')
    logging.info("Connecting det(L*G)")
    LG = k2.connect(LG)
    logging.info(f'LG shape = {LG.shape}')
    logging.info("Removing disambiguation symbols on L*G")
    LG.labels[LG.labels >= labels_disambig_id_start] = 0
    if isinstance(LG.aux_labels, torch.Tensor):
        LG.aux_labels[LG.aux_labels >= aux_labels_disambig_id_start] = 0
    else:
        LG.aux_labels.values()[
            LG.aux_labels.values() >= aux_labels_disambig_id_start] = 0
    logging.info("Removing epsilons")
    LG = k2.remove_epsilon(LG)
    logging.info(f'LG shape = {LG.shape}')
    logging.info("Connecting rm-eps(det(L*G))")
    LG = k2.connect(LG)
    logging.info(f'LG shape = {LG.shape}')
    LG.aux_labels = k2.ragged.remove_values_eq(LG.aux_labels, 0)
    logging.info("Arc sorting LG")
    LG = k2.arc_sort(LG)

    logging.info("Composing ctc_topo and LG")
    LG = k2.compose(ctc_topo, LG, inner_labels='phones')
    logging.info("Connecting ctc_topo*LG")
    LG = k2.connect(LG)
    logging.info("Arc sorting ctc_topo*LG")
    LG = k2.arc_sort(LG)
    logging.info(
        f'LG is arc sorted: {(LG.properties & k2.fsa_properties.ARC_SORTED) != 0}'
    )
    return LG
def test(self):
    devices = [torch.device('cpu')]
    if torch.cuda.is_available():
        devices.append(torch.device('cuda'))

    for device in devices:
        for use_identity_map, sorted_match_a in [(True, True),
                                                 (False, True),
                                                 (True, False),
                                                 (False, False)]:
            # recognizes (0|1)(0|2)
            s1 = '''
                0 1 0 0.1
                0 1 1 0.2
                1 2 0 0.4
                1 2 2 0.3
                2 3 -1 0.5
                3
            '''
            # recognizes 02*
            s2 = '''
                0 1 0 1
                1 1 2 2
                1 2 -1 3
                2
            '''
            # recognizes 1*0
            s3 = '''
                0 0 1 10
                0 1 0 20
                1 2 -1 30
                2
            '''
            a_fsa = k2.Fsa.from_str(s1).to(device)
            b_fsa_1 = k2.Fsa.from_str(s2).to(device)
            b_fsa_2 = k2.Fsa.from_str(s3).to(device)
            a_fsa.requires_grad_(True)
            b_fsa_1.requires_grad_(True)
            b_fsa_2.requires_grad_(True)

            b_fsas = k2.create_fsa_vec([b_fsa_1, b_fsa_2])
            if use_identity_map:
                a_fsas = k2.create_fsa_vec([a_fsa, a_fsa])
                b_to_a_map = torch.tensor([0, 1],
                                          dtype=torch.int32).to(device)
            else:
                a_fsas = k2.create_fsa_vec([a_fsa])
                b_to_a_map = torch.tensor([0, 0],
                                          dtype=torch.int32).to(device)

            c_fsas = k2.intersect_device(a_fsas, b_fsas, b_to_a_map,
                                         sorted_match_a)
            assert c_fsas.shape == (2, None, None)
            c_fsas = k2.connect(c_fsas.to('cpu'))
            # c_fsas[0] recognizes: 02
            # c_fsas[1] recognizes: 10
            actual_str_0 = k2.to_str(c_fsas[0])
            expected_str_0 = '\n'.join(
                ['0 1 0 1.1', '1 2 2 2.3', '2 3 -1 3.5', '3'])
            assert actual_str_0.strip() == expected_str_0

            actual_str_1 = k2.to_str(c_fsas[1])
            expected_str_1 = '\n'.join(
                ['0 1 1 10.2', '1 2 0 20.4', '2 3 -1 30.5', '3'])
            assert actual_str_1.strip() == expected_str_1

            loss = c_fsas.scores.sum()
            (-loss).backward()
            assert torch.allclose(
                a_fsa.grad,
                torch.tensor([-1, -1, -1, -1, -2]).to(a_fsa.grad))
            assert torch.allclose(
                b_fsa_1.grad,
                torch.tensor([-1, -1, -1]).to(b_fsa_1.grad))
            assert torch.allclose(
                b_fsa_2.grad,
                torch.tensor([-1, -1, -1]).to(b_fsa_2.grad))
def test_random_case2(self):
    # 2 sequences
    for device in self.devices:
        T1 = torch.randint(10, 200, (1,)).item()
        T2 = torch.randint(9, 100, (1,)).item()
        C = torch.randint(20, 30, (1,)).item()
        if T1 < T2:
            T1, T2 = T2, T1

        torch_activation_1 = torch.rand(
            (T1, C), dtype=torch.float32,
            device=device).requires_grad_(True)
        torch_activation_2 = torch.rand(
            (T2, C), dtype=torch.float32,
            device=device).requires_grad_(True)

        k2_activation_1 = torch_activation_1.detach().clone(
        ).requires_grad_(True)
        k2_activation_2 = torch_activation_2.detach().clone(
        ).requires_grad_(True)

        # [T, N, C]
        torch_activations = torch.nn.utils.rnn.pad_sequence(
            [torch_activation_1, torch_activation_2],
            batch_first=False,
            padding_value=0)

        # [N, T, C]
        k2_activations = torch.nn.utils.rnn.pad_sequence(
            [k2_activation_1, k2_activation_2],
            batch_first=True,
            padding_value=0)

        target_length1 = torch.randint(1, T1, (1,)).item()
        target_length2 = torch.randint(1, T2, (1,)).item()
        target_lengths = torch.tensor([target_length1,
                                       target_length2]).to(device)
        targets = torch.randint(1, C - 1,
                                (target_lengths.sum(),)).to(device)

        # [T, N, C]
        torch_log_probs = torch.nn.functional.log_softmax(
            torch_activations, dim=-1)
        input_lengths = torch.tensor([T1, T2]).to(device)

        torch_loss = torch.nn.functional.ctc_loss(
            log_probs=torch_log_probs,
            targets=targets,
            input_lengths=input_lengths,
            target_lengths=target_lengths,
            reduction='none')

        assert T1 >= T2
        supervision_segments = torch.tensor([[0, 0, T1], [1, 0, T2]],
                                            dtype=torch.int32)
        k2_log_probs = torch.nn.functional.log_softmax(k2_activations,
                                                       dim=-1)
        dense_fsa_vec = k2.DenseFsaVec(k2_log_probs,
                                       supervision_segments).to(device)

        ctc_topo_inv = k2.arc_sort(
            build_ctc_topo(list(range(C))).invert_())
        linear_fsa = k2.linear_fsa([
            targets[:target_length1].tolist(),
            targets[target_length1:].tolist()
        ])
        decoding_graph = k2.intersect(ctc_topo_inv, linear_fsa)
        decoding_graph = k2.connect(decoding_graph).invert_().to(device)

        target_graph = k2.intersect_dense(decoding_graph, dense_fsa_vec,
                                          100.0)

        k2_scores = target_graph.get_tot_scores(log_semiring=True,
                                                use_double_scores=False)
        assert torch.allclose(torch_loss, -1 * k2_scores)

        scale = torch.rand_like(torch_loss) * 100
        (torch_loss * scale).sum().backward()
        (-k2_scores * scale).sum().backward()
        assert torch.allclose(torch_activation_1.grad,
                              k2_activation_1.grad,
                              atol=1e-2)
        assert torch.allclose(torch_activation_2.grad,
                              k2_activation_2.grad,
                              atol=1e-2)
def test_case4(self):
    for device in self.devices:
        # put case3, case2 and case1 into a batch
        torch_activation_1 = torch.tensor(
            [[0., 0., 0., 0., 0.]]).to(device).requires_grad_(True)

        torch_activation_2 = torch.arange(1, 16).reshape(3, 5).to(
            torch.float32).to(device).requires_grad_(True)

        torch_activation_3 = torch.tensor([
            [-5, -4, -3, -2, -1],
            [-10, -9, -8, -7, -6],
            [-15, -14, -13, -12, -11.],
        ]).to(device).requires_grad_(True)

        k2_activation_1 = torch_activation_1.detach().clone(
        ).requires_grad_(True)
        k2_activation_2 = torch_activation_2.detach().clone(
        ).requires_grad_(True)
        k2_activation_3 = torch_activation_3.detach().clone(
        ).requires_grad_(True)

        # [T, N, C]
        torch_activations = torch.nn.utils.rnn.pad_sequence(
            [torch_activation_3, torch_activation_2, torch_activation_1],
            batch_first=False,
            padding_value=0)

        # [N, T, C]
        k2_activations = torch.nn.utils.rnn.pad_sequence(
            [k2_activation_3, k2_activation_2, k2_activation_1],
            batch_first=True,
            padding_value=0)

        # [[b, c], [c, c], [a]]
        targets = torch.tensor([2, 3, 3, 3, 1]).to(device)
        input_lengths = torch.tensor([3, 3, 1]).to(device)
        target_lengths = torch.tensor([2, 2, 1]).to(device)

        torch_log_probs = torch.nn.functional.log_softmax(
            torch_activations, dim=-1)  # (T, N, C)

        torch_loss = torch.nn.functional.ctc_loss(
            log_probs=torch_log_probs,
            targets=targets,
            input_lengths=input_lengths,
            target_lengths=target_lengths,
            reduction='none')

        assert torch.allclose(
            torch_loss,
            torch.tensor([4.938850402832, 7.355742931366,
                          1.6094379425049]).to(device))

        k2_log_probs = torch.nn.functional.log_softmax(k2_activations,
                                                       dim=-1)
        supervision_segments = torch.tensor(
            [[0, 0, 3], [1, 0, 3], [2, 0, 1]], dtype=torch.int32)
        dense_fsa_vec = k2.DenseFsaVec(k2_log_probs,
                                       supervision_segments).to(device)

        ctc_topo_inv = k2.arc_sort(
            build_ctc_topo([0, 1, 2, 3, 4]).invert_())
        # [[b, c], [c, c], [a]]
        linear_fsa = k2.linear_fsa([[2, 3], [3, 3], [1]])
        decoding_graph = k2.intersect(ctc_topo_inv, linear_fsa)
        decoding_graph = k2.connect(decoding_graph).invert_().to(device)

        target_graph = k2.intersect_dense(decoding_graph, dense_fsa_vec,
                                          100.0)

        k2_scores = target_graph.get_tot_scores(log_semiring=True,
                                                use_double_scores=False)
        assert torch.allclose(torch_loss, -1 * k2_scores)

        scale = torch.tensor([1., -2, 3.5]).to(device)
        (torch_loss * scale).sum().backward()
        (-k2_scores * scale).sum().backward()
        assert torch.allclose(torch_activation_1.grad,
                              k2_activation_1.grad)
        assert torch.allclose(torch_activation_2.grad,
                              k2_activation_2.grad)
        assert torch.allclose(torch_activation_3.grad,
                              k2_activation_3.grad)
def nbest_decoding(lats: k2.Fsa, num_paths: int):
    '''
    (Ideas of this function are from Dan)

    It implements something like CTC prefix beam search using n-best
    lists.

    The basic idea is to first extract n-best paths from the given
    lattice, build word sequences from these paths, and compute the
    total scores of these sequences in the log-semiring. The one with
    the max score is used as the decoding output.
    '''
    # First, extract `num_paths` paths for each sequence.
    # paths is a k2.RaggedInt with axes [seq][path][arc_pos]
    paths = k2.random_paths(lats, num_paths=num_paths,
                            use_double_scores=True)

    # word_seqs is a k2.RaggedInt sharing the same shape as `paths`
    # but it contains word IDs. Note that it also contains 0s and -1s.
    # The last entry in each sublist is -1.
    word_seqs = k2.index(lats.aux_labels, paths)
    # Note: the above operation supports also the case when
    # lats.aux_labels is a ragged tensor. In that case,
    # `remove_axis=True` is used inside the pybind11 binding code,
    # so the resulting `word_seqs` still has 3 axes, like `paths`.
    # The 3 axes are [seq][path][word]

    # Remove epsilons and -1 from word_seqs
    word_seqs = k2.ragged.remove_values_leq(word_seqs, 0)

    # Remove repeated sequences to avoid redundant computation later.
    #
    # Since k2.ragged.unique_sequences will reorder paths within a seq,
    # `new2old` is a 1-D torch.Tensor mapping from the output path
    # index to the input path index.
    # new2old.numel() == unique_word_seqs.num_elements()
    unique_word_seqs, _, new2old = k2.ragged.unique_sequences(
        word_seqs, need_num_repeats=False, need_new2old_indexes=True)
    # Note: unique_word_seqs still has the same axes as word_seqs

    seq_to_path_shape = k2.ragged.get_layer(unique_word_seqs.shape(), 0)

    # path_to_seq_map is a 1-D torch.Tensor.
    # path_to_seq_map[i] is the seq to which the i-th path belongs.
    path_to_seq_map = seq_to_path_shape.row_ids(1)

    # Remove the seq axis.
    # Now unique_word_seqs has only two axes [path][word]
    unique_word_seqs = k2.ragged.remove_axis(unique_word_seqs, 0)

    # word_fsas is an FsaVec with axes [path][state][arc]
    word_fsas = k2.linear_fsa(unique_word_seqs)

    word_fsas_with_epsilon_loops = k2.add_epsilon_self_loops(word_fsas)

    # lats has phone IDs as labels and word IDs as aux_labels.
    # inv_lats has word IDs as labels and phone IDs as aux_labels
    inv_lats = k2.invert(lats)
    inv_lats = k2.arc_sort(inv_lats)  # no-op if already arc-sorted

    path_lats = k2.intersect_device(inv_lats,
                                    word_fsas_with_epsilon_loops,
                                    b_to_a_map=path_to_seq_map,
                                    sorted_match_a=True)
    # path_lats has word IDs as labels and phone IDs as aux_labels

    path_lats = k2.top_sort(k2.connect(
        path_lats.to('cpu'))).to(lats.device)

    tot_scores = path_lats.get_tot_scores(True, True)
    # RaggedFloat currently supports float32 only.
    # We may bind Ragged<double> as RaggedDouble if needed.
    ragged_tot_scores = k2.RaggedFloat(seq_to_path_shape,
                                       tot_scores.to(torch.float32))

    argmax_indexes = k2.ragged.argmax_per_sublist(ragged_tot_scores)

    # Since we invoked `k2.ragged.unique_sequences`, which reorders
    # the index from `paths`, we use `new2old` here to convert
    # argmax_indexes to the indexes into `paths`.
    #
    # Use k2.index here since argmax_indexes' dtype is torch.int32
    best_path_indexes = k2.index(new2old, argmax_indexes)

    paths_2axes = k2.ragged.remove_axis(paths, 0)

    # best_paths is a k2.RaggedInt with 2 axes [path][arc_pos]
    best_paths = k2.index(paths_2axes, best_path_indexes)

    # labels is a k2.RaggedInt with 2 axes [path][phone_id]
    # Note that it contains -1s.
    labels = k2.index(lats.labels.contiguous(), best_paths)
    labels = k2.ragged.remove_values_eq(labels, -1)

    # lats.aux_labels is a k2.RaggedInt tensor with 2 axes, so
    # aux_labels is also a k2.RaggedInt with 2 axes
    aux_labels = k2.index(lats.aux_labels, best_paths.values())

    best_path_fsas = k2.linear_fsa(labels)
    best_path_fsas.aux_labels = aux_labels

    return best_path_fsas
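# A hedged toy illustration of the per-sequence argmax used above:
# two sequences with 3 and 2 paths respectively. argmax_per_sublist
# returns, for each sublist, an index into the *flattened* score
# tensor, which is why its output can be fed to k2.index directly.
import torch
import k2

shape = k2.RaggedInt('[ [0 0 0] [0 0] ]').shape()
scores = torch.tensor([0.1, 2.0, 0.5, 1.0, 3.0])
ragged_scores = k2.RaggedFloat(shape, scores)
print(k2.ragged.argmax_per_sublist(ragged_scores))  # indexes 1 and 4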
def rescore_with_whole_lattice(lats: k2.Fsa,
                               G_with_epsilon_loops: k2.Fsa,
                               lm_scale_list: List[float]
                               ) -> Dict[str, k2.Fsa]:
    '''Use whole lattice to rescore.

    Args:
      lats:
        An FsaVec. It can be the output of `k2.intersect_dense_pruned`.
      G_with_epsilon_loops:
        An FsaVec representing the language model (LM). Note that it
        is an FsaVec, but it contains only one Fsa.
      lm_scale_list:
        A list containing lm_scale values.

    Returns:
      A dict of FsaVec, whose key is an lm_scale and whose value
      represents the best decoding path for each sequence in the
      lattice.
    '''
    assert len(lats.shape) == 3
    assert hasattr(lats, 'lm_scores')
    assert G_with_epsilon_loops.shape == (1, None, None)

    device = lats.device
    lats.scores = lats.scores - lats.lm_scores
    # We will use lm_scores from G, so remove lats.lm_scores here
    del lats.lm_scores
    assert hasattr(lats, 'lm_scores') is False

    # Now, lats.scores contains only am_scores

    # inverted_lats has word IDs as labels.
    # Its aux_labels are phone IDs, which is a ragged tensor
    # k2.RaggedInt
    inverted_lats = k2.invert(lats)
    num_seqs = lats.shape[0]

    b_to_a_map = torch.zeros(num_seqs, device=device, dtype=torch.int32)
    try:
        rescoring_lats = k2.intersect_device(G_with_epsilon_loops,
                                             inverted_lats,
                                             b_to_a_map,
                                             sorted_match_a=True)
    except RuntimeError as e:
        print(f'Caught exception:\n{e}\n')
        print(f'Number of FSAs: {inverted_lats.shape[0]}')
        print('num_arcs before pruning: ',
              inverted_lats.arcs.num_elements())

        # NOTE(fangjun): The choice of the threshold 0.001 is arbitrary
        # here to avoid OOM. We may need to fine tune it.
        inverted_lats = k2.prune_on_arc_post(inverted_lats, 0.001, True)
        print('num_arcs after pruning: ',
              inverted_lats.arcs.num_elements())

        rescoring_lats = k2.intersect_device(G_with_epsilon_loops,
                                             inverted_lats,
                                             b_to_a_map,
                                             sorted_match_a=True)

    rescoring_lats = k2.top_sort(k2.connect(
        rescoring_lats.to('cpu'))).to(device)

    # inv_lats has phone IDs as labels
    # and word IDs as aux_labels.
    inv_lats = k2.invert(rescoring_lats)

    ans = dict()
    # The following implements
    #     scores = (scores - lm_scores) / lm_scale + lm_scores
    #            = scores / lm_scale + lm_scores * (1 - 1/lm_scale)
    saved_scores = inv_lats.scores.clone()
    for lm_scale in lm_scale_list:
        am_scores = saved_scores - inv_lats.lm_scores
        am_scores /= lm_scale
        inv_lats.scores = am_scores + inv_lats.lm_scores

        best_paths = k2.shortest_path(inv_lats, use_double_scores=True)
        key = f'lm_scale_{lm_scale}'
        ans[key] = best_paths
    return ans
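# A small numeric check (added for clarity, not in the original) of the
# lm_scale identity in the comment above: starting from
# scores = am + lm, each loop iteration recovers am / lm_scale + lm.
import torch

am, lm = torch.tensor([2.0]), torch.tensor([3.0])
scores = am + lm
for lm_scale in [0.5, 1.0, 2.0]:
    rescored = (scores - lm) / lm_scale + lm
    assert torch.allclose(rescored,
                          scores / lm_scale + lm * (1 - 1 / lm_scale))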
def rescore_with_whole_lattice(lats: k2.Fsa,
                               G_with_epsilon_loops: k2.Fsa) -> k2.Fsa:
    '''Use whole lattice to rescore.

    Args:
      lats:
        An FsaVec. It can be the output of `k2.intersect_dense_pruned`.
      G_with_epsilon_loops:
        An FsaVec representing the language model (LM). Note that it
        is an FsaVec, but it contains only one Fsa.

    Returns:
      An FsaVec representing the best decoding path for each sequence
      in the lattice.
    '''
    assert len(lats.shape) == 3
    assert hasattr(lats, 'lm_scores')
    assert G_with_epsilon_loops.shape == (1, None, None)

    device = lats.device
    lats.scores = lats.scores - lats.lm_scores
    # Now, lats.scores contains only am_scores

    # inverted_lats has word IDs as labels.
    # Its aux_labels are phone IDs, which is a ragged tensor
    # k2.RaggedInt
    inverted_lats = k2.invert(lats)
    num_seqs = lats.shape[0]

    inverted_lats_with_epsilon_loops = k2.add_epsilon_self_loops(
        inverted_lats)

    b_to_a_map = torch.zeros(num_seqs, device=device, dtype=torch.int32)
    try:
        rescoring_lats = k2.intersect_device(
            G_with_epsilon_loops,
            inverted_lats_with_epsilon_loops,
            b_to_a_map,
            sorted_match_a=True)
    except RuntimeError as e:
        print(f'Caught exception:\n{e}\n')
        print(f'Number of FSAs: {inverted_lats.shape[0]}')
        print('num_arcs before pruning: ',
              inverted_lats_with_epsilon_loops.arcs.num_elements())

        # NOTE(fangjun): The choice of the threshold 0.001 is arbitrary
        # here to avoid OOM. We may need to fine tune it.
        inverted_lats = k2.prune_on_arc_post(inverted_lats, 0.001, True)
        inverted_lats_with_epsilon_loops = k2.add_epsilon_self_loops(
            inverted_lats)
        print('num_arcs after pruning: ',
              inverted_lats_with_epsilon_loops.arcs.num_elements())

        rescoring_lats = k2.intersect_device(
            G_with_epsilon_loops,
            inverted_lats_with_epsilon_loops,
            b_to_a_map,
            sorted_match_a=True)

    rescoring_lats = k2.top_sort(k2.connect(
        rescoring_lats.to('cpu'))).to(device)

    inverted_rescoring_lats = k2.invert(rescoring_lats)
    # inverted_rescoring_lats has phone IDs as labels
    # and word IDs as aux_labels.
    inverted_rescoring_lats = k2.remove_epsilon_self_loops(
        inverted_rescoring_lats)
    best_paths = k2.shortest_path(inverted_rescoring_lats,
                                  use_double_scores=True)
    return best_paths
def rescore_with_n_best_list(lats: k2.Fsa, G: k2.Fsa,
                             num_paths: int) -> k2.Fsa:
    '''Decode using n-best list with LM rescoring.

    `lats` is a decoding lattice, which has 3 axes. This function first
    extracts `num_paths` paths from `lats` for each sequence using
    `k2.random_paths`. The `am_scores` of these paths are computed.
    For each path, its `lm_scores` is computed using `G` (which is an
    LM). The final `tot_scores` is the sum of `am_scores` and
    `lm_scores`. The path with the greatest `tot_scores` within a
    sequence is used as the decoding output.

    Args:
      lats:
        An FsaVec. It can be the output of `k2.intersect_dense_pruned`.
      G:
        An FsaVec representing the language model (LM). Note that it
        is an FsaVec, but it contains only one Fsa.
      num_paths:
        It is the size `n` in the `n-best` list.

    Returns:
      An FsaVec representing the best decoding path for each sequence
      in the lattice.
    '''
    device = lats.device
    assert len(lats.shape) == 3
    assert hasattr(lats, 'aux_labels')
    assert hasattr(lats, 'lm_scores')

    assert G.shape == (1, None, None)
    assert G.device == device
    assert hasattr(G, 'aux_labels') is False

    # First, extract `num_paths` paths for each sequence.
    # paths is a k2.RaggedInt with axes [seq][path][arc_pos]
    paths = k2.random_paths(lats, num_paths=num_paths,
                            use_double_scores=True)

    # word_seqs is a k2.RaggedInt sharing the same shape as `paths`
    # but it contains word IDs. Note that it also contains 0s and -1s.
    # The last entry in each sublist is -1.
    word_seqs = k2.index(lats.aux_labels, paths)

    # Remove epsilons and -1 from word_seqs
    word_seqs = k2.ragged.remove_values_leq(word_seqs, 0)

    # Remove repeated sequences to avoid redundant computation later.
    #
    # unique_word_seqs is still a k2.RaggedInt with 3 axes
    # [seq][path][word], except that there are no repeated paths with
    # the same word_seq within a seq.
    #
    # num_repeats is also a k2.RaggedInt with 2 axes containing the
    # multiplicities of each path.
    # num_repeats.num_elements() == unique_word_seqs.num_elements()
    #
    # Since k2.ragged.unique_sequences will reorder paths within a seq,
    # `new2old` is a 1-D torch.Tensor mapping from the output path
    # index to the input path index.
    # new2old.numel() == unique_word_seqs.num_elements()
    unique_word_seqs, num_repeats, new2old = k2.ragged.unique_sequences(
        word_seqs, need_num_repeats=True, need_new2old_indexes=True)

    seq_to_path_shape = k2.ragged.get_layer(unique_word_seqs.shape(), 0)

    # path_to_seq_map is a 1-D torch.Tensor.
    # path_to_seq_map[i] is the seq to which the i-th path belongs.
    path_to_seq_map = seq_to_path_shape.row_ids(1)

    # Remove the seq axis.
    # Now unique_word_seqs has only two axes [path][word]
    unique_word_seqs = k2.ragged.remove_axis(unique_word_seqs, 0)

    # word_fsas is an FsaVec with axes [path][state][arc]
    word_fsas = k2.linear_fsa(unique_word_seqs)

    word_fsas_with_epsilon_loops = k2.add_epsilon_self_loops(word_fsas)

    am_scores = compute_am_scores(lats, word_fsas_with_epsilon_loops,
                                  path_to_seq_map)

    # Now compute lm_scores
    b_to_a_map = torch.zeros_like(path_to_seq_map)
    lm_path_lats = _intersect_device(G,
                                     word_fsas_with_epsilon_loops,
                                     b_to_a_map=b_to_a_map,
                                     sorted_match_a=True)
    lm_path_lats = k2.top_sort(k2.connect(
        lm_path_lats.to('cpu'))).to(device)
    lm_scores = lm_path_lats.get_tot_scores(True, True)

    tot_scores = am_scores + lm_scores

    # Remember that we used `k2.ragged.unique_sequences` to remove
    # repeated paths to avoid redundant computation in
    # `k2.intersect_device`. Now we use `num_repeats` to correct the
    # scores for each path.
    #
    # NOTE(fangjun): It is commented out as it leads to a worse WER.
    # tot_scores = tot_scores * num_repeats.values()

    # TODO(fangjun): We may need to add `k2.RaggedDouble`
    ragged_tot_scores = k2.RaggedFloat(seq_to_path_shape,
                                       tot_scores.to(torch.float32))
    argmax_indexes = k2.ragged.argmax_per_sublist(ragged_tot_scores)

    # Use k2.index here since argmax_indexes' dtype is torch.int32
    best_path_indexes = k2.index(new2old, argmax_indexes)

    paths = k2.ragged.remove_axis(paths, 0)

    # best_paths is a k2.RaggedInt with 2 axes [path][arc_pos]
    best_paths = k2.index(paths, best_path_indexes)

    # labels is a k2.RaggedInt with 2 axes [path][phone_id]
    # Note that it contains -1s.
    labels = k2.index(lats.labels.contiguous(), best_paths)
    labels = k2.ragged.remove_values_eq(labels, -1)

    # lats.aux_labels is a k2.RaggedInt tensor with 2 axes, so
    # aux_labels is also a k2.RaggedInt with 2 axes
    aux_labels = k2.index(lats.aux_labels, best_paths.values())

    best_path_fsas = k2.linear_fsa(labels)
    best_path_fsas.aux_labels = aux_labels

    return best_path_fsas
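# A hedged toy illustration of k2.ragged.unique_sequences as used
# above: within each sequence, duplicate paths are removed; num_repeats
# counts the multiplicity of each surviving path and new2old maps it
# back to an original path index. The 3-axis ragged literal below is an
# assumption about the string constructor.
import k2

word_seqs = k2.RaggedInt('[ [[1 2] [1 2] [3]] [[4] [4]] ]')
unique, num_repeats, new2old = k2.ragged.unique_sequences(
    word_seqs, need_num_repeats=True, need_new2old_indexes=True)
print(unique, num_repeats, new2old)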