def get_hierarchical_targets(ys: List[List[int]],
                             lexicon: k2.Fsa) -> List[Tensor]:
    """Get hierarchical transcripts (i.e., phone level transcripts) from
    transcripts (i.e., word level transcripts).

    Args:
        ys: Word level transcripts.
        lexicon: Its labels are words, while its aux_labels are phones.

    Returns:
        List[Tensor]: Phone level transcripts.
    """
    if lexicon is None:
        return ys
    else:
        L_inv = lexicon

    n_batch = len(ys)
    indices = torch.tensor(range(n_batch))

    transcripts = k2.create_fsa_vec([k2.linear_fsa(x) for x in ys])
    transcripts_lexicon = k2.intersect(transcripts, L_inv)
    transcripts_lexicon = k2.arc_sort(k2.connect(transcripts_lexicon))
    transcripts_lexicon = k2.remove_epsilon(transcripts_lexicon)
    transcripts_lexicon = k2.shortest_path(transcripts_lexicon,
                                           use_double_scores=True)

    ys = get_texts(transcripts_lexicon, indices)
    ys = [torch.tensor(y) for y in ys]

    return ys
def build_num_graphs(self, texts: List[str]) -> k2.Fsa:
    '''Convert transcript to an Fsa with the help of lexicon
    and word symbol table.

    Args:
        texts:
            Each element is a transcript containing words separated by spaces.
            For instance, it may be 'HELLO SNOWFALL', which contains
            two words.

    Returns:
        Return an FST (FsaVec) corresponding to the transcript. Its `labels` are
        phone IDs and `aux_labels` are word IDs.
    '''
    word_ids_list = []
    for text in texts:
        word_ids = []
        for word in text.split(' '):
            if word in self.lexicon.words:
                word_ids.append(self.lexicon.words[word])
            else:
                word_ids.append(self.oov_id)
        word_ids_list.append(word_ids)

    fsa = k2.linear_fsa(word_ids_list, self.device)
    fsa = k2.add_epsilon_self_loops(fsa)
    assert fsa.device == self.device
    num_graphs = k2.intersect(self.L_inv, fsa,
                              treat_epsilons_specially=False).invert_()
    num_graphs = k2.arc_sort(num_graphs)
    return num_graphs
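# Usage sketch for the word-ID mapping above (a minimal, self-contained
# example; the symbol table, OOV id and texts below are made up, and the
# intersection with L_inv is omitted because it needs a real lexicon FST):
import k2

words = {'HELLO': 2, 'SNOWFALL': 3}      # hypothetical word symbol table
oov_id = 1                               # hypothetical OOV id
texts = ['HELLO SNOWFALL', 'HELLO WORLD']
word_ids_list = [[words.get(w, oov_id) for w in t.split(' ')] for t in texts]
fsa = k2.linear_fsa(word_ids_list)       # FsaVec with one linear FSA per text
fsa = k2.add_epsilon_self_loops(fsa)     # prepares it for epsilon-free intersection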
def test_fsa_vec(self):
    symbols = [
        [1, 3, 5],
        [2, 6],
        [8, 7, 9],
    ]
    num_symbols = sum([len(s) for s in symbols])
    fsa = k2.linear_fsa(symbols)
    assert len(fsa.shape) == 3
    assert fsa.shape[0] == 3, 'There should be 3 FSAs'
    expected_arcs = [
        # fsa 0
        [0, 1, 1],
        [1, 2, 3],
        [2, 3, 5],
        [3, 4, -1],
        # fsa 1
        [0, 1, 2],
        [1, 2, 6],
        [2, 3, -1],
        # fsa 2
        [0, 1, 8],
        [1, 2, 7],
        [2, 3, 9],
        [3, 4, -1]
    ]
    print(fsa.arcs.values()[:, :-1])
    assert torch.allclose(
        fsa.arcs.values()[:, :-1],  # skip the last field `scores`
        torch.tensor(expected_arcs, dtype=torch.int32))
    assert torch.allclose(
        fsa.scores,
        torch.zeros(num_symbols + len(symbols), dtype=torch.float32))
def test_case1(self):
    devices = [torch.device('cpu')]
    if torch.cuda.is_available():
        devices.append(torch.device('cuda'))
    for device in devices:
        # suppose we have five symbols: <blk>, a, b, c, d
        torch_activation = torch.tensor([0.2, 0.2, 0.2, 0.2, 0.2]).to(device)
        k2_activation = torch_activation.detach().clone()

        # (T, N, C)
        torch_activation = torch_activation.reshape(
            1, 1, -1).requires_grad_(True)

        # (N, T, C)
        k2_activation = k2_activation.reshape(1, 1, -1).requires_grad_(True)

        torch_log_probs = torch.nn.functional.log_softmax(
            torch_activation, dim=-1)  # (T, N, C)

        # we have only one sequence and its label is `a`
        targets = torch.tensor([1]).to(device)
        input_lengths = torch.tensor([1]).to(device)
        target_lengths = torch.tensor([1]).to(device)
        torch_loss = torch.nn.functional.ctc_loss(
            log_probs=torch_log_probs,
            targets=targets,
            input_lengths=input_lengths,
            target_lengths=target_lengths,
            reduction='none')

        assert torch.allclose(torch_loss,
                              torch.tensor([1.6094379425049]).to(device))

        # (N, T, C)
        k2_log_probs = torch.nn.functional.log_softmax(k2_activation, dim=-1)

        supervision_segments = torch.tensor([[0, 0, 1]], dtype=torch.int32)
        dense_fsa_vec = k2.DenseFsaVec(k2_log_probs,
                                       supervision_segments).to(device)

        ctc_topo_inv = k2.arc_sort(
            build_ctc_topo([0, 1, 2, 3, 4]).invert_())
        linear_fsa = k2.linear_fsa([1])
        decoding_graph = k2.intersect(ctc_topo_inv, linear_fsa)
        decoding_graph = k2.connect(decoding_graph).invert_().to(device)

        target_graph = k2.intersect_dense(decoding_graph, dense_fsa_vec, 100.0)

        k2_scores = target_graph.get_tot_scores(log_semiring=True,
                                                use_double_scores=False)
        assert torch.allclose(torch_loss, -1 * k2_scores)

        torch_loss.backward()
        (-k2_scores).backward()
        assert torch.allclose(torch_activation.grad, k2_activation.grad)
def compile_one_and_cache(self, text: str) -> Fsa:
    tokens = (token if token in self.vocab._sym2id else self.oov
              for token in text.split(' '))
    word_ids = [self.vocab.get(token) for token in tokens]
    fsa = k2.linear_fsa(word_ids)
    decoding_graph = k2.connect(k2.intersect(fsa, self.L_inv)).invert_()
    decoding_graph = k2.add_epsilon_self_loops(decoding_graph)
    return decoding_graph
def test_single_fsa(self):
    for device in self.devices:
        labels = [2, 0, 0, 0, 5, 8]
        src = k2.linear_fsa(labels, device)
        dst = k2.linear_fsa_with_self_loops(src)
        assert src.device == dst.device
        expected_labels = [0, 2, 0, 5, 0, 8, 0, -1]
        assert dst.labels.tolist() == expected_labels
def test_case3(self):
    devices = [torch.device('cpu')]
    if torch.cuda.is_available():
        devices.append(torch.device('cuda'))
    for device in devices:
        # (T, N, C)
        torch_activation = torch.tensor([[
            [-5, -4, -3, -2, -1],
            [-10, -9, -8, -7, -6],
            [-15, -14, -13, -12, -11.],
        ]]).permute(1, 0, 2).to(device).requires_grad_(True)
        torch_activation = torch_activation.to(torch.float32)
        torch_activation.requires_grad_(True)

        k2_activation = torch_activation.detach().clone().requires_grad_(
            True)

        torch_log_probs = torch.nn.functional.log_softmax(
            torch_activation, dim=-1)  # (T, N, C)
        # we have only one sequence and its labels are `b,c`
        targets = torch.tensor([2, 3]).to(device)
        input_lengths = torch.tensor([3]).to(device)
        target_lengths = torch.tensor([2]).to(device)

        torch_loss = torch.nn.functional.ctc_loss(
            log_probs=torch_log_probs,
            targets=targets,
            input_lengths=input_lengths,
            target_lengths=target_lengths,
            reduction='none')

        act = k2_activation.permute(1, 0, 2)  # (T, N, C) -> (N, T, C)
        k2_log_probs = torch.nn.functional.log_softmax(act, dim=-1)

        supervision_segments = torch.tensor([[0, 0, 3]], dtype=torch.int32)
        dense_fsa_vec = k2.DenseFsaVec(k2_log_probs,
                                       supervision_segments).to(device)

        ctc_topo_inv = k2.arc_sort(
            build_ctc_topo([0, 1, 2, 3, 4]).invert_())
        linear_fsa = k2.linear_fsa([2, 3])
        decoding_graph = k2.intersect(ctc_topo_inv, linear_fsa)
        decoding_graph = k2.connect(decoding_graph).invert_().to(device)

        target_graph = k2.intersect_dense(decoding_graph, dense_fsa_vec, 100.0)

        k2_scores = target_graph.get_tot_scores(log_semiring=True,
                                                use_double_scores=False)
        assert torch.allclose(torch_loss, -1 * k2_scores)
        assert torch.allclose(torch_loss,
                              torch.tensor([4.938850402832]).to(device))

        torch_loss.backward()
        (-k2_scores).backward()
        assert torch.allclose(torch_activation.grad, k2_activation.grad)
def generate_nbest_list(lats: Fsa, num_paths: int) -> Nbest:
    '''Generate an n-best list from a lattice.

    Args:
      lats:
        The decoding lattice from the first pass after LM rescoring.
        lats is an FsaVec. It can be the return value of
        :func:`whole_lattice_rescoring`
      num_paths:
        Size of n for n-best list. CAUTION: After removing paths
        that represent the same token sequences, the number of paths
        in different sequences may not be equal.
    Return:
      Return an Nbest object. Note the returned FSAs don't have epsilon
      self-loops.
    '''
    assert len(lats.shape) == 3

    # CAUTION: We use `phones` instead of `tokens` here because
    # :func:`compile_HLG` uses `phones`
    #
    # Note: compile_HLG is from k2-fsa/snowfall
    assert hasattr(lats, 'phones')

    assert not hasattr(lats, 'tokens')
    lats.tokens = lats.phones
    # we use tokens instead of phones in the following code

    # First, extract `num_paths` paths for each sequence.
    # paths is a k2.RaggedTensor with axes [seq][path][arc_pos]
    paths = k2.random_paths(lats, num_paths=num_paths, use_double_scores=True)

    # token_seqs is a k2.RaggedTensor sharing the same shape as `paths`
    # but it contains token IDs. Note that it also contains 0s and -1s.
    # The last entry in each sublist is -1.
    # Its axes are [seq][path][token_id]
    token_seqs = k2.ragged.index(lats.tokens, paths)

    # Remove epsilons (0s) and -1 from token_seqs
    token_seqs = token_seqs.remove_values_leq(0)

    # unique_token_seqs is still a k2.RaggedTensor with axes
    # [seq][path][token_id].
    # But the number of paths in each sequence may be different.
    unique_token_seqs, _, _ = token_seqs.unique(
        need_num_repeats=False, need_new2old_indexes=False)

    seq_to_path_shape = unique_token_seqs.shape.get_layer(0)

    # Remove the seq axis.
    # Now unique_token_seqs has only two axes [path][token_id]
    unique_token_seqs = unique_token_seqs.remove_axis(0)

    token_fsas = k2.linear_fsa(unique_token_seqs)

    return Nbest(fsa=token_fsas, shape=seq_to_path_shape)
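# Minimal standalone sketch of the last two steps above, with made-up token
# IDs: 0s (epsilons) and -1s are stripped from a ragged tensor of token
# sequences, then k2.linear_fsa() turns each remaining sequence into one
# linear FSA of the resulting FsaVec.
import k2

token_seqs = k2.RaggedTensor([[5, 0, 9, -1], [7, 7, -1]])
token_seqs = token_seqs.remove_values_leq(0)  # [[5, 9], [7, 7]]
token_fsas = k2.linear_fsa(token_seqs)
assert token_fsas.shape == (2, None, None)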
def compile_one_and_cache(self, text: str) -> k2.Fsa:
    tokens = (token if token in self.words else self.oov
              for token in text.split(' '))
    word_ids = [self.words[token] for token in tokens]
    fsa = k2.linear_fsa(word_ids)
    decoding_graph = k2.connect(k2.intersect(fsa, self.L_inv)).invert_()
    decoding_graph = k2.arc_sort(decoding_graph)
    decoding_graph = k2.compose(self.ctc_topo, decoding_graph)
    decoding_graph = k2.connect(decoding_graph)
    return decoding_graph
def test_random_case1(self):
    # 1 sequence
    devices = [torch.device('cpu')]
    if torch.cuda.is_available():
        devices.append(torch.device('cuda', 0))
    for device in devices:
        T = torch.randint(10, 100, (1,)).item()
        C = torch.randint(20, 30, (1,)).item()
        torch_activation = torch.rand((1, T + 10, C),
                                      dtype=torch.float32,
                                      device=device).requires_grad_(True)
        k2_activation = torch_activation.detach().clone().requires_grad_(
            True)

        # [N, T, C] -> [T, N, C]
        torch_log_probs = torch.nn.functional.log_softmax(
            torch_activation.permute(1, 0, 2), dim=-1)

        input_lengths = torch.tensor([T]).to(device)
        target_lengths = torch.randint(1, T, (1,)).to(device)
        targets = torch.randint(1, C - 1,
                                (target_lengths.item(),)).to(device)

        torch_loss = torch.nn.functional.ctc_loss(
            log_probs=torch_log_probs,
            targets=targets,
            input_lengths=input_lengths,
            target_lengths=target_lengths,
            reduction='none')
        k2_log_probs = torch.nn.functional.log_softmax(k2_activation, dim=-1)
        supervision_segments = torch.tensor([[0, 0, T]], dtype=torch.int32)
        dense_fsa_vec = k2.DenseFsaVec(k2_log_probs,
                                       supervision_segments).to(device)

        ctc_topo_inv = k2.arc_sort(
            build_ctc_topo(list(range(C))).invert_())
        linear_fsa = k2.linear_fsa([targets.tolist()])
        decoding_graph = k2.intersect(ctc_topo_inv, linear_fsa)
        decoding_graph = k2.connect(decoding_graph).invert_().to(device)

        target_graph = k2.intersect_dense(decoding_graph, dense_fsa_vec, 100.0)

        k2_scores = target_graph.get_tot_scores(log_semiring=True,
                                                use_double_scores=False)
        assert torch.allclose(torch_loss, -1 * k2_scores)
        scale = torch.rand_like(torch_loss) * 100
        (torch_loss * scale).sum().backward()
        (-k2_scores * scale).sum().backward()
        assert torch.allclose(torch_activation.grad,
                              k2_activation.grad,
                              atol=1e-2)
def test_single_fsa(self):
    symbols = [2, 5, 8]
    fsa = k2.linear_fsa(symbols)
    assert len(fsa.shape) == 2
    assert fsa.shape[0] == len(symbols) + 2, 'There should be 5 states'
    assert torch.allclose(
        fsa.scores,
        torch.zeros(len(symbols) + 1, dtype=torch.float32))
    assert torch.allclose(
        fsa.arcs.values()[:, :-1],  # skip the last field `scores`
        torch.tensor([[0, 1, 2], [1, 2, 5], [2, 3, 8], [3, 4, -1]],
                     dtype=torch.int32))
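# Minimal standalone sketch of what the test above checks: a linear FSA over
# [2, 5, 8] has len(labels) + 2 states, one arc per label plus a final -1 arc,
# and all arc scores initialised to zero.
import k2

fsa = k2.linear_fsa([2, 5, 8])
print(fsa.labels)   # tensor([ 2,  5,  8, -1], dtype=torch.int32)
print(fsa.scores)   # tensor([0., 0., 0., 0.])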
def test_case1(self):
    for device in self.devices:
        # suppose we have five symbols: <blk>, a, b, c, d
        torch_activation = torch.tensor([0.2, 0.2, 0.2, 0.2, 0.2]).to(device)
        k2_activation = torch_activation.detach().clone()

        # (T, N, C)
        torch_activation = torch_activation.reshape(
            1, 1, -1).requires_grad_(True)

        # (N, T, C)
        k2_activation = k2_activation.reshape(1, 1, -1).requires_grad_(True)

        torch_log_probs = torch.nn.functional.log_softmax(
            torch_activation, dim=-1)  # (T, N, C)

        # we have only one sequence and its label is `a`
        targets = torch.tensor([1]).to(device)
        input_lengths = torch.tensor([1]).to(device)
        target_lengths = torch.tensor([1]).to(device)
        torch_loss = torch.nn.functional.ctc_loss(
            log_probs=torch_log_probs,
            targets=targets,
            input_lengths=input_lengths,
            target_lengths=target_lengths,
            reduction='mean')

        assert torch.allclose(torch_loss,
                              torch.tensor([1.6094379425049]).to(device))

        # (N, T, C)
        k2_log_probs = torch.nn.functional.log_softmax(k2_activation, dim=-1)

        supervision_segments = torch.tensor([[0, 0, 1]], dtype=torch.int32)
        dense_fsa_vec = k2.DenseFsaVec(k2_log_probs,
                                       supervision_segments).to(device)

        ctc_topo = k2.ctc_topo(4)
        linear_fsa = k2.linear_fsa([1])
        decoding_graph = k2.compose(ctc_topo, linear_fsa).to(device)

        k2_loss = k2.ctc_loss(decoding_graph,
                              dense_fsa_vec,
                              reduction='mean',
                              target_lengths=target_lengths)

        assert torch.allclose(torch_loss, k2_loss)

        torch_loss.backward()
        k2_loss.backward()
        assert torch.allclose(torch_activation.grad, k2_activation.grad)
def test_case3(self):
    for device in self.devices:
        # (T, N, C)
        torch_activation = torch.tensor([[
            [-5, -4, -3, -2, -1],
            [-10, -9, -8, -7, -6],
            [-15, -14, -13, -12, -11.],
        ]]).permute(1, 0, 2).to(device).requires_grad_(True)
        torch_activation = torch_activation.to(torch.float32)
        torch_activation.requires_grad_(True)

        k2_activation = torch_activation.detach().clone().requires_grad_(
            True)

        torch_log_probs = torch.nn.functional.log_softmax(
            torch_activation, dim=-1)  # (T, N, C)
        # we have only one sequence and its labels are `b,c`
        targets = torch.tensor([2, 3]).to(device)
        input_lengths = torch.tensor([3]).to(device)
        target_lengths = torch.tensor([2]).to(device)

        torch_loss = torch.nn.functional.ctc_loss(
            log_probs=torch_log_probs,
            targets=targets,
            input_lengths=input_lengths,
            target_lengths=target_lengths,
            reduction='mean')

        act = k2_activation.permute(1, 0, 2)  # (T, N, C) -> (N, T, C)
        k2_log_probs = torch.nn.functional.log_softmax(act, dim=-1)

        supervision_segments = torch.tensor([[0, 0, 3]], dtype=torch.int32)
        dense_fsa_vec = k2.DenseFsaVec(k2_log_probs,
                                       supervision_segments).to(device)

        ctc_topo = k2.ctc_topo(4)
        linear_fsa = k2.linear_fsa([2, 3])
        decoding_graph = k2.compose(ctc_topo, linear_fsa).to(device)

        k2_loss = k2.ctc_loss(decoding_graph,
                              dense_fsa_vec,
                              reduction='mean',
                              target_lengths=target_lengths)

        expected_loss = torch.tensor([4.938850402832],
                                     device=device) / target_lengths

        assert torch.allclose(torch_loss, k2_loss)
        assert torch.allclose(torch_loss, expected_loss)

        torch_loss.backward()
        k2_loss.backward()
        assert torch.allclose(torch_activation.grad, k2_activation.grad)
def create_decoding_graph(texts, L, symbols):
    word_ids_list = []
    for text in texts:
        filter_text = [
            i if i in symbols._sym2id else '<UNK>' for i in text.split(' ')
        ]
        word_ids = [symbols.get(i) for i in filter_text]
        word_ids_list.append(word_ids)

    fsa = k2.linear_fsa(word_ids_list)
    decoding_graph = k2.intersect(fsa, L).invert_()
    decoding_graph = k2.add_epsilon_self_loops(decoding_graph)
    return decoding_graph
def test_multiple_fsa(self):
    for device in self.devices:
        labels = [[2, 0, 0, 0, 5, 0, 0, 0, 8, 0, 0], [1, 2],
                  [0, 0, 0, 3, 0, 2]]
        src = k2.linear_fsa(labels, device)
        dst = k2.linear_fsa_with_self_loops(src)
        assert src.device == dst.device
        expected_labels0 = [0, 2, 0, 5, 0, 8, 0, -1]
        expected_labels1 = [0, 1, 0, 2, 0, -1]
        expected_labels2 = [0, 3, 0, 2, 0, -1]
        expected_labels = expected_labels0 + expected_labels1 + expected_labels2  # noqa
        assert dst.labels.tolist() == expected_labels
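# Minimal standalone sketch of the behaviour exercised above:
# k2.linear_fsa_with_self_loops() drops 0s (blanks/epsilons) from the labels
# and adds an epsilon self-loop at every state, so the flattened label list
# interleaves 0s with the remaining non-zero labels.
import k2

src = k2.linear_fsa([2, 0, 0, 5, 8])
dst = k2.linear_fsa_with_self_loops(src)
print(dst.labels.tolist())  # [0, 2, 0, 5, 0, 8, 0, -1]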
def create_decoding_graph(texts, graph, symbols):
    fsas = []
    for text in texts:
        filter_text = [
            i if i in symbols._sym2id else '<UNK>' for i in text.split(' ')
        ]
        word_ids = [symbols.get(i) for i in filter_text]
        fsa = k2.linear_fsa(word_ids)
        fsa = k2.arc_sort(fsa)
        decoding_graph = k2.intersect(fsa, graph).invert_()
        decoding_graph = k2.add_epsilon_self_loops(decoding_graph)
        fsas.append(decoding_graph)
    return k2.create_fsa_vec(fsas)
def test_case2(self):
    for device in self.devices:
        # (T, N, C)
        torch_activation = torch.arange(1, 16).reshape(1, 3, 5).permute(
            1, 0, 2).to(device)
        torch_activation = torch_activation.to(torch.float32)
        torch_activation.requires_grad_(True)

        k2_activation = torch_activation.detach().clone().requires_grad_(
            True)

        torch_log_probs = torch.nn.functional.log_softmax(
            torch_activation, dim=-1)  # (T, N, C)
        # we have only one sequence and its labels are `c,c`
        targets = torch.tensor([3, 3]).to(device)
        input_lengths = torch.tensor([3]).to(device)
        target_lengths = torch.tensor([2]).to(device)

        torch_loss = torch.nn.functional.ctc_loss(
            log_probs=torch_log_probs,
            targets=targets,
            input_lengths=input_lengths,
            target_lengths=target_lengths,
            reduction='none')

        act = k2_activation.permute(1, 0, 2)  # (T, N, C) -> (N, T, C)
        k2_log_probs = torch.nn.functional.log_softmax(act, dim=-1)

        supervision_segments = torch.tensor([[0, 0, 3]], dtype=torch.int32)
        dense_fsa_vec = k2.DenseFsaVec(k2_log_probs,
                                       supervision_segments).to(device)

        ctc_topo_inv = k2.arc_sort(
            build_ctc_topo([0, 1, 2, 3, 4]).invert_())
        linear_fsa = k2.linear_fsa([3, 3])
        decoding_graph = k2.intersect(ctc_topo_inv, linear_fsa)
        decoding_graph = k2.connect(decoding_graph).invert_().to(device)

        target_graph = k2.intersect_dense(decoding_graph, dense_fsa_vec, 100.0)

        k2_scores = target_graph.get_tot_scores(log_semiring=True,
                                                use_double_scores=False)
        assert torch.allclose(torch_loss, -1 * k2_scores)
        assert torch.allclose(torch_loss,
                              torch.tensor([7.355742931366]).to(device))

        torch_loss.backward()
        (-k2_scores).backward()
        assert torch.allclose(torch_activation.grad, k2_activation.grad)
def compile(self, targets: torch.Tensor,
            target_lengths: torch.Tensor) -> 'k2.Fsa':
    token_ids_list = [
        t[:l].tolist() for t, l in zip(targets, target_lengths)
    ]
    # see https://github.com/k2-fsa/k2/issues/835
    label_graph = k2.linear_fsa(token_ids_list).to(self.device)
    label_graph.aux_labels = label_graph.labels.clone()
    decoding_graphs = compose_with_self_loops(self.base_graph, label_graph)
    decoding_graphs = k2.arc_sort(decoding_graphs).to(self.device)

    # make sure the gradient is not accumulated
    decoding_graphs.requires_grad_(False)
    return decoding_graphs
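# Minimal sketch of the padded-targets handling above (the tensors are made up;
# compose_with_self_loops and self.base_graph from the snippet are omitted
# here because they are not defined in this example):
import torch
import k2

targets = torch.tensor([[3, 7, 7, 0], [5, 2, 0, 0]])   # padded label IDs
target_lengths = torch.tensor([3, 2])
token_ids_list = [t[:l].tolist() for t, l in zip(targets, target_lengths)]
label_graph = k2.linear_fsa(token_ids_list)             # FsaVec, one FSA per utterance
label_graph.aux_labels = label_graph.labels.clone()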
def test_random_case1(self):
    # 1 sequence
    for device in self.devices:
        T = torch.randint(10, 100, (1, )).item()
        C = torch.randint(20, 30, (1, )).item()
        torch_activation = torch.rand((1, T + 10, C),
                                      dtype=torch.float32,
                                      device=device).requires_grad_(True)

        k2_activation = torch_activation.detach().clone().requires_grad_(
            True)

        # [N, T, C] -> [T, N, C]
        torch_log_probs = torch.nn.functional.log_softmax(
            torch_activation.permute(1, 0, 2), dim=-1)

        input_lengths = torch.tensor([T]).to(device)
        target_lengths = torch.randint(1, T, (1, )).to(device)
        targets = torch.randint(1, C - 1,
                                (target_lengths.item(), )).to(device)

        torch_loss = torch.nn.functional.ctc_loss(
            log_probs=torch_log_probs,
            targets=targets,
            input_lengths=input_lengths,
            target_lengths=target_lengths,
            reduction='mean')
        k2_log_probs = torch.nn.functional.log_softmax(k2_activation, dim=-1)
        supervision_segments = torch.tensor([[0, 0, T]], dtype=torch.int32)
        dense_fsa_vec = k2.DenseFsaVec(k2_log_probs,
                                       supervision_segments).to(device)

        ctc_topo = k2.ctc_topo(C - 1)
        linear_fsa = k2.linear_fsa([targets.tolist()])
        decoding_graph = k2.compose(ctc_topo, linear_fsa).to(device)

        k2_loss = k2.ctc_loss(decoding_graph,
                              dense_fsa_vec,
                              reduction='mean',
                              target_lengths=target_lengths)

        assert torch.allclose(torch_loss, k2_loss)
        scale = torch.rand_like(torch_loss) * 100
        (torch_loss * scale).sum().backward()
        (k2_loss * scale).sum().backward()
        assert torch.allclose(torch_activation.grad,
                              k2_activation.grad,
                              atol=1e-2)
def test_from_ragged_int_single_fsa(self):
    for device in self.devices:
        ragged_int = k2.RaggedInt('[ [10 20] ]').to(device)
        fsa = k2.linear_fsa(ragged_int)
        assert fsa.shape == (1, None, None)
        assert fsa.device == device
        expected_arcs = torch.tensor([[0, 1, 10], [1, 2, 20], [2, 3, -1]],
                                     dtype=torch.int32,
                                     device=device)
        assert torch.all(
            torch.eq(
                fsa.arcs.values()[:, :-1],  # skip the last field `scores`
                expected_arcs))
        assert torch.all(torch.eq(fsa.scores,
                                  torch.zeros_like(fsa.scores)))
def nbest_am_lm_scores(
    lats: k2.Fsa,
    num_paths: int,
    device: str = "cuda",
    batch_size: int = 500,
):
    """Compute AM and LM scores for the n-best paths sampled from a lattice.

    Compatible with both ctc_decoding and TLG decoding.
    """
    paths = k2.random_paths(lats, num_paths=num_paths, use_double_scores=True)
    if isinstance(lats.aux_labels, torch.Tensor):
        word_seqs = k2.ragged.index(lats.aux_labels.contiguous(), paths)
    else:
        # '_k2.RaggedInt' object has no attribute 'contiguous'
        word_seqs = lats.aux_labels.index(paths)
        word_seqs = word_seqs.remove_axis(word_seqs.num_axes - 2)

    # With ctc_decoding, word_seqs stores token_ids.
    # With TLG decoding, word_seqs stores word_ids.
    word_seqs = word_seqs.remove_values_leq(0)
    unique_word_seqs, num_repeats, new2old = word_seqs.unique(
        need_num_repeats=True, need_new2old_indexes=True
    )

    seq_to_path_shape = unique_word_seqs.shape.get_layer(0)
    path_to_seq_map = seq_to_path_shape.row_ids(1)
    # used to split final computed tot_scores
    seq_to_path_splits = seq_to_path_shape.row_splits(1)

    unique_word_seqs = unique_word_seqs.remove_axis(0)
    word_fsas = k2.linear_fsa(unique_word_seqs)

    word_fsas_with_epsilon_loops = k2.add_epsilon_self_loops(word_fsas)

    am_scores, lm_scores = compute_am_scores_and_lm_scores(
        lats, word_fsas_with_epsilon_loops, path_to_seq_map, device, batch_size
    )

    token_seqs = k2.ragged.index(lats.labels.contiguous(), paths)
    token_seqs = token_seqs.remove_axis(0)

    token_ids, _ = token_seqs.index(new2old, axis=0)
    token_ids = token_ids.tolist()
    # Now remove repeated tokens and 0s and -1s.
    token_ids = [remove_repeated_and_leq(tokens) for tokens in token_ids]

    return am_scores, lm_scores, token_ids, new2old, path_to_seq_map, seq_to_path_splits
def test_single_fsa(self):
    for device in self.devices:
        labels = [2, 5, 8]
        fsa = k2.linear_fsa(labels, device)
        assert fsa.device == device
        assert len(fsa.shape) == 2
        assert fsa.shape[0] == len(labels) + 2, 'There should be 5 states'
        assert torch.all(torch.eq(fsa.scores,
                                  torch.zeros_like(fsa.scores)))
        assert torch.all(
            torch.eq(
                fsa.arcs.values()[:, :-1],  # skip the last field `scores`
                torch.tensor([[0, 1, 2], [1, 2, 5], [2, 3, 8], [3, 4, -1]],
                             dtype=torch.int32,
                             device=device)))
def test_single_fsa(self):
    devices = [torch.device('cpu')]
    if torch.cuda.is_available():
        devices.append(torch.device('cuda', 0))
    for device in devices:
        labels = [2, 5, 8]
        fsa = k2.linear_fsa(labels, device)
        assert fsa.device == device
        assert len(fsa.shape) == 2
        assert fsa.shape[0] == len(labels) + 2, 'There should be 5 states'
        assert torch.allclose(fsa.scores,
                              torch.zeros(len(labels) + 1).to(fsa.scores))
        assert torch.all(
            torch.eq(
                fsa.arcs.values()[:, :-1],  # skip the last field `scores`
                torch.tensor([[0, 1, 2], [1, 2, 5], [2, 3, 8], [3, 4, -1]],
                             dtype=torch.int32,
                             device=device)))
def test_from_ragged_int_two_fsas(self):
    for device in self.devices:
        ragged = k2.RaggedTensor([[10, 20], [100, 200, 300]]).to(device)
        fsa = k2.linear_fsa(ragged)
        assert fsa.shape == (2, None, None)
        assert fsa.device == device
        expected_arcs = torch.tensor(
            [[0, 1, 10], [1, 2, 20], [2, 3, -1], [0, 1, 100], [1, 2, 200],
             [2, 3, 300], [3, 4, -1]],
            dtype=torch.int32,
            device=device)
        assert torch.all(
            torch.eq(
                fsa.arcs.values()[:, :-1],  # skip the last field `scores`
                expected_arcs))
        assert torch.all(torch.eq(fsa.scores,
                                  torch.zeros_like(fsa.scores)))
def compile_one_and_cache(self, text: str) -> k2.Fsa:
    '''Convert transcript to an Fsa with the help of lexicon
    and word symbol table.

    Args:
      text:
        The transcript containing words separated by spaces.

    Returns:
      Return an FST corresponding to the transcript. Its `labels` are
      phone IDs and `aux_labels` are word IDs.
    '''
    tokens = (token if token in self.words else self.oov
              for token in text.split(' '))
    word_ids = [self.words[token] for token in tokens]
    fsa = k2.linear_fsa(word_ids)
    num_graph = k2.connect(k2.intersect(fsa, self.L_inv)).invert_()
    num_graph = k2.arc_sort(num_graph)
    return num_graph
def create_decoding_graph(texts, L, symbols):
    fsas = []
    for text in texts:
        filter_text = [
            i if i in symbols._sym2id else '<UNK>' for i in text.split(' ')
        ]
        word_ids = [symbols.get(i) for i in filter_text]
        fsa = k2.linear_fsa(word_ids)
        print("linear fsa is ", fsa)
        fsa = k2.arc_sort(fsa)
        print("linear fsa, arc-sorted, is ", fsa)
        print("begin")
        print(k2.is_arc_sorted(k2.get_properties(fsa)))
        decoding_graph = k2.intersect(fsa, L).invert_()
        print("linear fsa, composed, is ", fsa)
        print("decoding graph is ", decoding_graph)
        decoding_graph = k2.add_epsilon_self_loops(decoding_graph)
        print("decoding graph with self-loops is ", decoding_graph)
        fsas.append(decoding_graph)
    return k2.create_fsa_vec(fsas)
def get_hierarchical_targets(ys: List[List[int]],
                             lexicon: k2.Fsa) -> List[Tensor]:
    """Get hierarchical transcripts (i.e., phone level transcripts) from
    transcripts (i.e., word level transcripts).

    Args:
        ys: Word level transcripts.
        lexicon: Its labels are words, while its aux_labels are phones.

    Returns:
        List[Tensor]: Phone level transcripts.
    """
    if lexicon is None:
        return ys
    else:
        L_inv = lexicon

    n_batch = len(ys)
    indices = torch.tensor(range(n_batch))
    device = L_inv.device

    transcripts = k2.create_fsa_vec(
        [k2.linear_fsa(x, device=device) for x in ys])
    transcripts_with_self_loops = k2.add_epsilon_self_loops(transcripts)

    transcripts_lexicon = k2.intersect(L_inv,
                                       transcripts_with_self_loops,
                                       treat_epsilons_specially=False)
    # Don't call invert_() above because we want to return phone IDs,
    # which is the `aux_labels` of transcripts_lexicon
    transcripts_lexicon = k2.remove_epsilon(transcripts_lexicon)
    transcripts_lexicon = k2.top_sort(transcripts_lexicon)

    transcripts_lexicon = k2.shortest_path(transcripts_lexicon,
                                           use_double_scores=True)

    ys = get_texts(transcripts_lexicon, indices)
    ys = [torch.tensor(y) for y in ys]

    return ys
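# Minimal sketch of the first half of the routine above with made-up word IDs:
# each word-level transcript becomes a linear FSA, the FSAs are stacked into an
# FsaVec, and epsilon self-loops are added before intersecting with L_inv
# (the intersection itself is omitted here because it needs a real lexicon FST).
import k2

ys = [[12, 7, 9], [3, 5]]
transcripts = k2.create_fsa_vec([k2.linear_fsa(x) for x in ys])
transcripts_with_self_loops = k2.add_epsilon_self_loops(transcripts)
assert transcripts_with_self_loops.shape[0] == len(ys)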
def test_from_ragged_int_two_fsas(self):
    devices = [torch.device('cpu')]
    if torch.cuda.is_available():
        devices.append(torch.device('cuda', 0))
    for device in devices:
        ragged_int = k2.RaggedInt('[ [10 20] [100 200 300] ]').to(device)
        fsa = k2.linear_fsa(ragged_int)
        assert fsa.shape == (2, None, None)
        assert fsa.device == device
        expected_arcs = torch.tensor(
            [[0, 1, 10], [1, 2, 20], [2, 3, -1], [0, 1, 100], [1, 2, 200],
             [2, 3, 300], [3, 4, -1]],
            dtype=torch.int32,
            device=device)
        assert torch.all(
            torch.eq(
                fsa.arcs.values()[:, :-1],  # skip the last field `scores`
                expected_arcs))
        assert torch.all(torch.eq(fsa.scores,
                                  torch.zeros_like(fsa.scores)))
def test_fsa_vec(self):
    devices = [torch.device('cpu')]
    if torch.cuda.is_available():
        devices.append(torch.device('cuda', 0))
    for device in devices:
        labels = [
            [1, 3, 5],
            [2, 6],
            [8, 7, 9],
        ]
        num_labels = sum([len(s) for s in labels])
        fsa = k2.linear_fsa(labels, device)
        assert len(fsa.shape) == 3
        assert fsa.device == device
        assert fsa.shape[0] == 3, 'There should be 3 FSAs'
        expected_arcs = [
            # fsa 0
            [0, 1, 1],
            [1, 2, 3],
            [2, 3, 5],
            [3, 4, -1],
            # fsa 1
            [0, 1, 2],
            [1, 2, 6],
            [2, 3, -1],
            # fsa 2
            [0, 1, 8],
            [1, 2, 7],
            [2, 3, 9],
            [3, 4, -1]
        ]
        assert torch.all(
            torch.eq(
                fsa.arcs.values()[:, :-1],  # skip the last field `scores`
                torch.tensor(expected_arcs,
                             dtype=torch.int32,
                             device=device)))
        assert torch.allclose(
            fsa.scores,
            torch.zeros(num_labels + len(labels)).to(fsa.scores))
def test_fsa_vec(self):
    for device in self.devices:
        labels = [
            [1, 3, 5],
            [2, 6],
            [8, 7, 9],
        ]
        fsa = k2.linear_fsa(labels, device)
        assert len(fsa.shape) == 3
        assert fsa.device == device
        assert fsa.shape[0] == 3, 'There should be 3 FSAs'
        expected_arcs = [
            # fsa 0
            [0, 1, 1],
            [1, 2, 3],
            [2, 3, 5],
            [3, 4, -1],
            # fsa 1
            [0, 1, 2],
            [1, 2, 6],
            [2, 3, -1],
            # fsa 2
            [0, 1, 8],
            [1, 2, 7],
            [2, 3, 9],
            [3, 4, -1]
        ]
        assert torch.all(
            torch.eq(
                fsa.arcs.values()[:, :-1],  # skip the last field `scores`
                torch.tensor(expected_arcs,
                             dtype=torch.int32,
                             device=device)))
        assert torch.all(torch.eq(fsa.scores,
                                  torch.zeros_like(fsa.scores)))
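# Minimal standalone sketch of the FsaVec case above: passing a list of label
# lists produces one linear FSA per list, each ending with a final -1 arc.
import k2

fsa_vec = k2.linear_fsa([[1, 3, 5], [2, 6]])
assert fsa_vec.shape == (2, None, None)
print(fsa_vec.labels.tolist())  # [1, 3, 5, -1, 2, 6, -1]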