def test_no_repeated(self):
    # standard ctc topo and modified ctc topo
    # should be equivalent if there are no
    # repeated neighboring symbols in the transcript
    max_token = 3
    standard = k2.ctc_topo(max_token, modified=False)
    modified = k2.ctc_topo(max_token, modified=True)
    transcript = k2.linear_fsa([1, 2, 3])
    standard_graph = k2.compose(standard, transcript)
    modified_graph = k2.compose(modified, transcript)

    input1 = k2.linear_fsa([1, 1, 1, 0, 0, 2, 2, 3, 3])
    input2 = k2.linear_fsa([1, 1, 0, 0, 2, 2, 0, 3, 3])
    inputs = [input1, input2]
    for i in inputs:
        lattice1 = k2.intersect(standard_graph, i,
                                treat_epsilons_specially=False)
        lattice2 = k2.intersect(modified_graph, i,
                                treat_epsilons_specially=False)
        lattice1 = k2.connect(lattice1)
        lattice2 = k2.connect(lattice2)

        aux_labels1 = lattice1.aux_labels[lattice1.aux_labels != 0]
        aux_labels2 = lattice2.aux_labels[lattice2.aux_labels != 0]
        aux_labels1 = aux_labels1[:-1]  # remove -1
        aux_labels2 = aux_labels2[:-1]
        assert torch.all(torch.eq(aux_labels1, aux_labels2))
        assert torch.all(torch.eq(aux_labels2, torch.tensor([1, 2, 3])))
def test(self):
    # for the symbol table
    # <eps> 0
    # a 1
    # b 2
    # c 3

    # an FSA that recognizes a+(b|c)
    s = '''
        0 1 1 0.1
        1 1 1 0.2
        1 2 2 0.3
        1 3 3 0.4
        2 4 -1 0.5
        3 4 -1 0.6
        4
    '''
    a_fsa = k2.Fsa.from_str(s)
    a_fsa.requires_grad_(True)

    # an FSA that recognizes ab
    s = '''
        0 1 1 10
        1 2 2 20
        2 3 -1 30
        3
    '''
    b_fsa = k2.Fsa.from_str(s)
    b_fsa.requires_grad_(True)

    fsa = k2.intersect(a_fsa, b_fsa)
    assert len(fsa.shape) == 2

    actual_str = k2.to_str(fsa)
    expected_str = '\n'.join(
        ['0 1 1 10.1', '1 2 2 20.3', '2 3 -1 30.5', '3'])
    assert actual_str.strip() == expected_str

    loss = fsa.scores.sum()
    loss.backward()
    # arc 0, 2, and 4 of a_fsa are kept in the final intersected FSA
    assert torch.allclose(
        a_fsa.scores.grad,
        torch.tensor([1, 0, 1, 0, 1, 0], dtype=torch.float32))

    assert torch.allclose(b_fsa.scores.grad,
                          torch.tensor([1, 1, 1], dtype=torch.float32))

    # if any of the input FSAs is an FsaVec,
    # the output FSA is also an FsaVec.
    a_fsa.scores.grad = None
    b_fsa.scores.grad = None
    a_fsa = k2.create_fsa_vec([a_fsa])
    fsa = k2.intersect(a_fsa, b_fsa)
    assert len(fsa.shape) == 3
def test_treat_epsilon_specially_false(self):
    devices = [torch.device('cpu')]
    if torch.cuda.is_available() and k2.with_cuda:
        devices.append(torch.device('cuda'))
    for device in devices:
        # a_fsa recognizes `(0|1)2*`
        s1 = '''
            0 1 0 0.1
            0 1 1 0.2
            1 1 2 0.3
            1 2 -1 0.4
            2
        '''
        a_fsa = k2.Fsa.from_str(s1).to(device)
        a_fsa.requires_grad_(True)

        # b_fsa recognizes `1|2`
        s2 = '''
            0 1 1 1
            0 1 2 2
            1 2 -1 3
            2
        '''
        b_fsa = k2.Fsa.from_str(s2).to(device)
        b_fsa.requires_grad_(True)

        # fsa recognizes `1`
        fsa = k2.intersect(a_fsa, b_fsa, treat_epsilons_specially=False)
        assert len(fsa.shape) == 2

        actual_str = k2.to_str_simple(fsa)
        expected_str = '\n'.join(['0 1 1 1.2', '1 2 -1 3.4', '2'])
        assert actual_str.strip() == expected_str

        loss = fsa.scores.sum()
        (-loss).backward()
        # arc 1 and 3 of a_fsa are kept in the final intersected FSA
        assert torch.allclose(a_fsa.grad,
                              torch.tensor([0, -1, 0, -1]).to(a_fsa.grad))

        # arc 0 and 2 of b_fsa are kept in the final intersected FSA
        assert torch.allclose(b_fsa.grad,
                              torch.tensor([-1, 0, -1]).to(b_fsa.grad))

        # if any of the input FSAs is an FsaVec,
        # the output FSA is also an FsaVec.
        a_fsa.scores.grad = None
        b_fsa.scores.grad = None
        a_fsa = k2.create_fsa_vec([a_fsa])
        fsa = k2.intersect(a_fsa, b_fsa, treat_epsilons_specially=False)
        assert len(fsa.shape) == 3
def test_treat_epsilon_specially_true(self):
    # this version works only on CPU and requires
    # arc-sorted inputs

    # a_fsa recognizes `(1|3)?2*`
    s1 = '''
        0 1 3 0.0
        0 1 1 0.2
        0 1 0 0.1
        1 1 2 0.3
        1 2 -1 0.4
        2
    '''
    a_fsa = k2.Fsa.from_str(s1)
    a_fsa.requires_grad_(True)

    # b_fsa recognizes `1|2|5`
    s2 = '''
        0 1 5 0
        0 1 1 1
        0 1 2 2
        1 2 -1 3
        2
    '''
    b_fsa = k2.Fsa.from_str(s2)
    b_fsa.requires_grad_(True)

    # fsa recognizes `1|2`
    fsa = k2.intersect(k2.arc_sort(a_fsa), k2.arc_sort(b_fsa))
    assert len(fsa.shape) == 2

    actual_str = k2.to_str_simple(fsa)
    expected_str = '\n'.join(
        ['0 1 0 0.1', '0 2 1 1.2', '1 2 2 2.3', '2 3 -1 3.4', '3'])
    assert actual_str.strip() == expected_str

    loss = fsa.scores.sum()
    (-loss).backward()
    # arc 1, 2, 3, and 4 of a_fsa are kept in the final intersected FSA
    assert torch.allclose(a_fsa.grad,
                          torch.tensor([0, -1, -1, -1, -1]).to(a_fsa.grad))

    # arc 1, 2, and 3 of b_fsa are kept in the final intersected FSA
    assert torch.allclose(b_fsa.grad,
                          torch.tensor([0, -1, -1, -1]).to(b_fsa.grad))

    # if any of the input FSAs is an FsaVec,
    # the output FSA is also an FsaVec.
    a_fsa.scores.grad = None
    b_fsa.scores.grad = None
    a_fsa = k2.create_fsa_vec([a_fsa])
    fsa = k2.intersect(k2.arc_sort(a_fsa), k2.arc_sort(b_fsa))
    assert len(fsa.shape) == 3
def test_with_repeated(self):
    max_token = 2
    standard = k2.ctc_topo(max_token, modified=False)
    modified = k2.ctc_topo(max_token, modified=True)
    transcript = k2.linear_fsa([1, 2, 2])
    standard_graph = k2.compose(standard, transcript)
    modified_graph = k2.compose(modified, transcript)

    # There is a blank separating the two 2's in the input,
    # so the standard and modified ctc topos should be equivalent
    input = k2.linear_fsa([1, 1, 2, 2, 0, 2, 2, 0, 0])
    lattice1 = k2.intersect(standard_graph, input,
                            treat_epsilons_specially=False)
    lattice2 = k2.intersect(modified_graph, input,
                            treat_epsilons_specially=False)
    lattice1 = k2.connect(lattice1)
    lattice2 = k2.connect(lattice2)

    aux_labels1 = lattice1.aux_labels[lattice1.aux_labels != 0]
    aux_labels2 = lattice2.aux_labels[lattice2.aux_labels != 0]
    aux_labels1 = aux_labels1[:-1]  # remove -1
    aux_labels2 = aux_labels2[:-1]
    assert torch.all(torch.eq(aux_labels1, aux_labels2))
    assert torch.all(torch.eq(aux_labels1, torch.tensor([1, 2, 2])))

    # There are no blanks separating the two 2's in the input.
    # The standard ctc topo requires that there must be a blank
    # separating them, so lattice1 below is empty.
    input = k2.linear_fsa([1, 1, 2, 2, 0, 0])
    lattice1 = k2.intersect(standard_graph, input,
                            treat_epsilons_specially=False)
    lattice2 = k2.intersect(modified_graph, input,
                            treat_epsilons_specially=False)
    lattice1 = k2.connect(lattice1)
    lattice2 = k2.connect(lattice2)

    assert lattice1.num_arcs == 0

    # Since there are two 2's in the input and there are also two 2's
    # in the transcript, the final output contains only one path.
    # If there were more than two 2's in the input, the output
    # would contain more than one path.
    aux_labels2 = lattice2.aux_labels[lattice2.aux_labels != 0]
    aux_labels2 = aux_labels2[:-1]  # remove -1
    assert torch.all(torch.eq(aux_labels2, torch.tensor([1, 2, 2])))
def compile(self, texts: Iterable[str],
            P: k2.Fsa) -> Tuple[k2.Fsa, k2.Fsa]:
    '''Create numerator and denominator graphs from transcripts
    and the bigram phone LM.

    Args:
      texts:
        A list of transcripts. Within a transcript, words are
        separated by spaces.
      P:
        The bigram phone LM created by :func:`create_bigram_phone_lm`.

    Returns:
      A tuple (num_graph, den_graph), where

        - `num_graph` is the numerator graph. It is an FsaVec with
          shape `(len(texts), None, None)`.

        - `den_graph` is the denominator graph. It is an FsaVec with the
          same shape as `num_graph`.
    '''
    assert P.is_cpu()

    ctc_topo_P = k2.intersect(self.ctc_topo, P).invert_()
    ctc_topo_P = k2.connect(ctc_topo_P)

    num_graphs = k2.create_fsa_vec(
        [self.compile_one_and_cache(text) for text in texts])

    num = k2.compose(ctc_topo_P, num_graphs)
    num = k2.connect(num)
    num = k2.arc_sort(num)

    den = k2.create_fsa_vec([ctc_topo_P.detach()] * len(texts))

    return num, den
def test_case1(self):
    devices = [torch.device('cpu')]
    if torch.cuda.is_available():
        devices.append(torch.device('cuda'))
    for device in devices:
        # suppose we have five symbols: <blk>, a, b, c, d
        torch_activation = torch.tensor([0.2, 0.2, 0.2, 0.2,
                                         0.2]).to(device)
        k2_activation = torch_activation.detach().clone()

        # (T, N, C)
        torch_activation = torch_activation.reshape(
            1, 1, -1).requires_grad_(True)

        # (N, T, C)
        k2_activation = k2_activation.reshape(1, 1,
                                              -1).requires_grad_(True)

        torch_log_probs = torch.nn.functional.log_softmax(
            torch_activation, dim=-1)  # (T, N, C)

        # we have only one sequence and its label is `a`
        targets = torch.tensor([1]).to(device)
        input_lengths = torch.tensor([1]).to(device)
        target_lengths = torch.tensor([1]).to(device)
        torch_loss = torch.nn.functional.ctc_loss(
            log_probs=torch_log_probs,
            targets=targets,
            input_lengths=input_lengths,
            target_lengths=target_lengths,
            reduction='none')

        assert torch.allclose(torch_loss,
                              torch.tensor([1.6094379425049]).to(device))

        # (N, T, C)
        k2_log_probs = torch.nn.functional.log_softmax(k2_activation,
                                                       dim=-1)

        supervision_segments = torch.tensor([[0, 0, 1]],
                                            dtype=torch.int32)
        dense_fsa_vec = k2.DenseFsaVec(k2_log_probs,
                                       supervision_segments).to(device)

        ctc_topo_inv = k2.arc_sort(
            build_ctc_topo([0, 1, 2, 3, 4]).invert_())
        linear_fsa = k2.linear_fsa([1])
        decoding_graph = k2.intersect(ctc_topo_inv, linear_fsa)
        decoding_graph = k2.connect(decoding_graph).invert_().to(device)

        target_graph = k2.intersect_dense(decoding_graph, dense_fsa_vec,
                                          100.0)

        k2_scores = target_graph.get_tot_scores(log_semiring=True,
                                                use_double_scores=False)
        assert torch.allclose(torch_loss, -1 * k2_scores)

        torch_loss.backward()
        (-k2_scores).backward()
        assert torch.allclose(torch_activation.grad, k2_activation.grad)
def build_num_graphs(self, texts: List[str]) -> k2.Fsa:
    '''Convert transcripts to an FsaVec with the help of the lexicon
    and the word symbol table.

    Args:
      texts:
        Each element is a transcript containing words separated by spaces.
        For instance, it may be 'HELLO SNOWFALL', which contains
        two words.

    Returns:
      Return an FST (FsaVec) corresponding to the transcripts. Its `labels`
      are phone IDs and `aux_labels` are word IDs.
    '''
    word_ids_list = []
    for text in texts:
        word_ids = []
        for word in text.split(' '):
            if word in self.words:
                word_ids.append(self.words[word])
            else:
                word_ids.append(self.oov_id)
        word_ids_list.append(word_ids)

    fsa = k2.linear_fsa(word_ids_list, self.device)
    fsa = k2.add_epsilon_self_loops(fsa)

    num_graphs = k2.intersect(self.L_inv,
                              fsa,
                              treat_epsilons_specially=False).invert_()
    num_graphs = k2.arc_sort(num_graphs)
    return num_graphs
def get_hierarchical_targets(ys: List[List[int]],
                             lexicon: k2.Fsa) -> List[Tensor]:
    """Get hierarchical transcripts (i.e., phone level transcripts) from
    transcripts (i.e., word level transcripts).

    Args:
        ys: Word level transcripts.
        lexicon: Its labels are words, while its aux_labels are phones.

    Returns:
        List[Tensor]: Phone level transcripts.
    """
    if lexicon is None:
        return ys
    else:
        L_inv = lexicon

    n_batch = len(ys)
    indices = torch.tensor(range(n_batch))

    transcripts = k2.create_fsa_vec([k2.linear_fsa(x) for x in ys])
    transcripts_lexicon = k2.intersect(transcripts, L_inv)
    transcripts_lexicon = k2.arc_sort(k2.connect(transcripts_lexicon))
    transcripts_lexicon = k2.remove_epsilon(transcripts_lexicon)
    transcripts_lexicon = k2.shortest_path(transcripts_lexicon,
                                           use_double_scores=True)

    ys = get_texts(transcripts_lexicon, indices)
    ys = [torch.tensor(y) for y in ys]

    return ys
def compile_one_and_cache(self, text: str) -> Fsa:
    tokens = (token if token in self.vocab._sym2id else self.oov
              for token in text.split(' '))
    word_ids = [self.vocab.get(token) for token in tokens]
    fsa = k2.linear_fsa(word_ids)
    decoding_graph = k2.connect(k2.intersect(fsa, self.L_inv)).invert_()
    decoding_graph = k2.add_epsilon_self_loops(decoding_graph)
    return decoding_graph
def compile(self,
            texts: Iterable[str],
            P: k2.Fsa,
            replicate_den: bool = True) -> Tuple[k2.Fsa, k2.Fsa]:
    '''Create numerator and denominator graphs from transcripts
    and the bigram phone LM.

    Args:
      texts:
        A list of transcripts. Within a transcript, words are
        separated by spaces.
      P:
        The bigram phone LM created by :func:`create_bigram_phone_lm`.
      replicate_den:
        If True, the returned den_graph is replicated to match the number
        of FSAs in the returned num_graph; if False, the returned
        den_graph contains only a single FSA.

    Returns:
      A tuple (num_graph, den_graph), where

        - `num_graph` is the numerator graph. It is an FsaVec with
          shape `(len(texts), None, None)`.

        - `den_graph` is the denominator graph. It is an FsaVec with the
          same shape as `num_graph` if replicate_den is True; otherwise,
          it is an FsaVec containing only a single FSA.
    '''
    assert P.device == self.device
    P_with_self_loops = k2.add_epsilon_self_loops(P)

    ctc_topo_P = k2.intersect(self.ctc_topo_inv,
                              P_with_self_loops,
                              treat_epsilons_specially=False).invert()
    ctc_topo_P = k2.arc_sort(ctc_topo_P)

    num_graphs = self.build_num_graphs(texts)
    num_graphs_with_self_loops = k2.remove_epsilon_and_add_self_loops(
        num_graphs)
    num_graphs_with_self_loops = k2.arc_sort(num_graphs_with_self_loops)

    num = k2.compose(ctc_topo_P,
                     num_graphs_with_self_loops,
                     treat_epsilons_specially=False)
    num = k2.arc_sort(num)

    ctc_topo_P_vec = k2.create_fsa_vec([ctc_topo_P.detach()])
    if replicate_den:
        indexes = torch.zeros(len(texts),
                              dtype=torch.int32,
                              device=self.device)
        den = k2.index_fsa(ctc_topo_P_vec, indexes)
    else:
        den = ctc_topo_P_vec

    return num, den
def intersect_with_self_loops(base_graph: 'k2.Fsa',
                              aux_graph: 'k2.Fsa') -> 'k2.Fsa':
    """Intersection helper function."""
    assert hasattr(base_graph, "aux_labels")
    assert not hasattr(aux_graph, "aux_labels")
    aux_graph_with_self_loops = k2.arc_sort(
        k2.add_epsilon_self_loops(aux_graph)).to(base_graph.device)
    result = k2.intersect(k2.arc_sort(base_graph),
                          aux_graph_with_self_loops,
                          treat_epsilons_specially=False)
    setattr(result, "phones", result.labels)
    return result
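# The following is a minimal usage sketch (not part of the original code) for
# intersect_with_self_loops() defined above. It assumes `k2` is imported and
# the function is in scope; the label and aux-label IDs are made up purely
# for illustration.
def _demo_intersect_with_self_loops():
    # base_graph: a linear FST whose labels could be phone IDs and whose
    # aux_labels could be word IDs
    base_graph = k2.linear_fst([1, 2, 3], [10, 20, 30])
    # aux_graph: a plain acceptor over the same label set, without aux_labels
    aux_graph = k2.linear_fsa([1, 2, 3])
    result = intersect_with_self_loops(base_graph, aux_graph)
    # the labels of the intersection are exposed under the `phones` attribute
    print(result.phones)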
def test_case3(self):
    devices = [torch.device('cpu')]
    if torch.cuda.is_available():
        devices.append(torch.device('cuda'))
    for device in devices:
        # (T, N, C)
        torch_activation = torch.tensor([[
            [-5, -4, -3, -2, -1],
            [-10, -9, -8, -7, -6],
            [-15, -14, -13, -12, -11.],
        ]]).permute(1, 0, 2).to(device).requires_grad_(True)
        torch_activation = torch_activation.to(torch.float32)
        torch_activation.requires_grad_(True)

        k2_activation = torch_activation.detach().clone().requires_grad_(
            True)

        torch_log_probs = torch.nn.functional.log_softmax(
            torch_activation, dim=-1)  # (T, N, C)

        # we have only one sequence and its labels are `b,c`
        targets = torch.tensor([2, 3]).to(device)
        input_lengths = torch.tensor([3]).to(device)
        target_lengths = torch.tensor([2]).to(device)

        torch_loss = torch.nn.functional.ctc_loss(
            log_probs=torch_log_probs,
            targets=targets,
            input_lengths=input_lengths,
            target_lengths=target_lengths,
            reduction='none')

        act = k2_activation.permute(1, 0, 2)  # (T, N, C) -> (N, T, C)
        k2_log_probs = torch.nn.functional.log_softmax(act, dim=-1)

        supervision_segments = torch.tensor([[0, 0, 3]],
                                            dtype=torch.int32)
        dense_fsa_vec = k2.DenseFsaVec(k2_log_probs,
                                       supervision_segments).to(device)

        ctc_topo_inv = k2.arc_sort(
            build_ctc_topo([0, 1, 2, 3, 4]).invert_())
        linear_fsa = k2.linear_fsa([2, 3])
        decoding_graph = k2.intersect(ctc_topo_inv, linear_fsa)
        decoding_graph = k2.connect(decoding_graph).invert_().to(device)

        target_graph = k2.intersect_dense(decoding_graph, dense_fsa_vec,
                                          100.0)

        k2_scores = target_graph.get_tot_scores(log_semiring=True,
                                                use_double_scores=False)
        assert torch.allclose(torch_loss, -1 * k2_scores)
        assert torch.allclose(torch_loss,
                              torch.tensor([4.938850402832]).to(device))

        torch_loss.backward()
        (-k2_scores).backward()
        assert torch.allclose(torch_activation.grad, k2_activation.grad)
def compile_one_and_cache(self, text: str) -> k2.Fsa:
    tokens = (token if token in self.words else self.oov
              for token in text.split(' '))
    word_ids = [self.words[token] for token in tokens]
    fsa = k2.linear_fsa(word_ids)
    decoding_graph = k2.connect(k2.intersect(fsa, self.L_inv)).invert_()
    decoding_graph = k2.arc_sort(decoding_graph)
    decoding_graph = k2.compose(self.ctc_topo, decoding_graph)
    decoding_graph = k2.connect(decoding_graph)
    return decoding_graph
def test_random_case1(self):
    # 1 sequence
    devices = [torch.device('cpu')]
    if torch.cuda.is_available():
        devices.append(torch.device('cuda', 0))
    for device in devices:
        T = torch.randint(10, 100, (1,)).item()
        C = torch.randint(20, 30, (1,)).item()
        torch_activation = torch.rand((1, T + 10, C),
                                      dtype=torch.float32,
                                      device=device).requires_grad_(True)
        k2_activation = torch_activation.detach().clone().requires_grad_(
            True)

        # [N, T, C] -> [T, N, C]
        torch_log_probs = torch.nn.functional.log_softmax(
            torch_activation.permute(1, 0, 2), dim=-1)

        input_lengths = torch.tensor([T]).to(device)
        target_lengths = torch.randint(1, T, (1,)).to(device)
        targets = torch.randint(1, C - 1,
                                (target_lengths.item(),)).to(device)

        torch_loss = torch.nn.functional.ctc_loss(
            log_probs=torch_log_probs,
            targets=targets,
            input_lengths=input_lengths,
            target_lengths=target_lengths,
            reduction='none')
        k2_log_probs = torch.nn.functional.log_softmax(k2_activation,
                                                       dim=-1)
        supervision_segments = torch.tensor([[0, 0, T]],
                                            dtype=torch.int32)
        dense_fsa_vec = k2.DenseFsaVec(k2_log_probs,
                                       supervision_segments).to(device)

        ctc_topo_inv = k2.arc_sort(
            build_ctc_topo(list(range(C))).invert_())
        linear_fsa = k2.linear_fsa([targets.tolist()])
        decoding_graph = k2.intersect(ctc_topo_inv, linear_fsa)
        decoding_graph = k2.connect(decoding_graph).invert_().to(device)

        target_graph = k2.intersect_dense(decoding_graph, dense_fsa_vec,
                                          100.0)

        k2_scores = target_graph.get_tot_scores(log_semiring=True,
                                                use_double_scores=False)
        assert torch.allclose(torch_loss, -1 * k2_scores)

        scale = torch.rand_like(torch_loss) * 100
        (torch_loss * scale).sum().backward()
        (-k2_scores * scale).sum().backward()
        assert torch.allclose(torch_activation.grad,
                              k2_activation.grad,
                              atol=1e-2)
def create_decoding_graph(texts, L, symbols):
    word_ids_list = []
    for text in texts:
        filter_text = [
            i if i in symbols._sym2id else '<UNK>' for i in text.split(' ')
        ]
        word_ids = [symbols.get(i) for i in filter_text]
        word_ids_list.append(word_ids)

    fsa = k2.linear_fsa(word_ids_list)
    decoding_graph = k2.intersect(fsa, L).invert_()
    decoding_graph = k2.add_epsilon_self_loops(decoding_graph)
    return decoding_graph
def compile(self, texts: Iterable[str],
            P: k2.Fsa) -> Tuple[k2.Fsa, k2.Fsa, k2.Fsa]:
    '''Create numerator and denominator graphs from transcripts
    and the bigram phone LM.

    Args:
      texts:
        A list of transcripts. Within a transcript, words are
        separated by spaces.
      P:
        The bigram phone LM created by :func:`create_bigram_phone_lm`.

    Returns:
      A tuple (num_graph, den_graph, decoding_graph), where

        - `num_graph` is the numerator graph. It is an FsaVec with
          shape `(len(texts), None, None)`.
          It is the result of compose(ctc_topo, P, L, transcript).

        - `den_graph` is the denominator graph. It is an FsaVec with the
          same shape as `num_graph`. It is the result of
          compose(ctc_topo, P).

        - `decoding_graph` is the result of
          compose(ctc_topo, L_disambig, G). Note that it is a single Fsa,
          not an FsaVec.
    '''
    assert P.device == self.device
    P_with_self_loops = k2.add_epsilon_self_loops(P)

    ctc_topo_P = k2.intersect(self.ctc_topo_inv,
                              P_with_self_loops,
                              treat_epsilons_specially=False).invert()
    ctc_topo_P = k2.arc_sort(ctc_topo_P)

    num_graphs = self.build_num_graphs(texts)
    num_graphs_with_self_loops = k2.remove_epsilon_and_add_self_loops(
        num_graphs)
    num_graphs_with_self_loops = k2.arc_sort(num_graphs_with_self_loops)

    num = k2.compose(ctc_topo_P,
                     num_graphs_with_self_loops,
                     treat_epsilons_specially=False,
                     inner_labels='phones')
    num = k2.arc_sort(num)

    ctc_topo_P_vec = k2.create_fsa_vec([ctc_topo_P.detach()])
    indexes = torch.zeros(len(texts),
                          dtype=torch.int32,
                          device=self.device)
    den = k2.index_fsa(ctc_topo_P_vec, indexes)

    return num, den, self.decoding_graph
def create_decoding_graph(texts, graph, symbols):
    fsas = []
    for text in texts:
        filter_text = [
            i if i in symbols._sym2id else '<UNK>' for i in text.split(' ')
        ]
        word_ids = [symbols.get(i) for i in filter_text]
        fsa = k2.linear_fsa(word_ids)
        fsa = k2.arc_sort(fsa)
        decoding_graph = k2.intersect(fsa, graph).invert_()
        decoding_graph = k2.add_epsilon_self_loops(decoding_graph)
        fsas.append(decoding_graph)
    return k2.create_fsa_vec(fsas)
def __init__(self,
             lexicon: Lexicon,
             P: k2.Fsa,
             device: torch.device,
             oov: str = '<UNK>'):
    '''
    Args:
      lexicon:
        The lexicon. Its `L_inv` has words as labels and phones as
        aux_labels; it also provides the phone and word symbol tables.
      P:
        A phone bigram LM if the pronunciations in the lexicon are in
        phones; a word piece bigram if the pronunciations in the lexicon
        are word pieces.
      device:
        The device on which the graphs are constructed.
      oov:
        Out of vocabulary word.
    '''
    self.lexicon = lexicon
    L_inv = self.lexicon.L_inv.to(device)
    P = P.to(device)

    # arc sort L_inv if it is not already arc sorted
    if L_inv.properties & k2.fsa_properties.ARC_SORTED == 0:
        L_inv = k2.arc_sort(L_inv)

    assert L_inv.requires_grad is False
    assert oov in self.lexicon.words

    self.L_inv = L_inv
    self.oov_id = self.lexicon.words[oov]
    self.oov = oov
    self.device = device

    phone_symbols = get_phone_symbols(self.lexicon.phones)
    phone_symbols_with_blank = [0] + phone_symbols

    ctc_topo = build_ctc_topo(phone_symbols_with_blank).to(device)
    assert ctc_topo.requires_grad is False

    ctc_topo_inv = k2.arc_sort(ctc_topo.invert_())

    P_with_self_loops = k2.add_epsilon_self_loops(P)

    ctc_topo_P = k2.intersect(ctc_topo_inv,
                              P_with_self_loops,
                              treat_epsilons_specially=False).invert()

    self.ctc_topo_P = k2.arc_sort(ctc_topo_P)
def test_case2(self):
    for device in self.devices:
        # (T, N, C)
        torch_activation = torch.arange(1, 16).reshape(1, 3, 5).permute(
            1, 0, 2).to(device)
        torch_activation = torch_activation.to(torch.float32)
        torch_activation.requires_grad_(True)

        k2_activation = torch_activation.detach().clone().requires_grad_(
            True)

        torch_log_probs = torch.nn.functional.log_softmax(
            torch_activation, dim=-1)  # (T, N, C)

        # we have only one sequence and its labels are `c,c`
        targets = torch.tensor([3, 3]).to(device)
        input_lengths = torch.tensor([3]).to(device)
        target_lengths = torch.tensor([2]).to(device)

        torch_loss = torch.nn.functional.ctc_loss(
            log_probs=torch_log_probs,
            targets=targets,
            input_lengths=input_lengths,
            target_lengths=target_lengths,
            reduction='none')

        act = k2_activation.permute(1, 0, 2)  # (T, N, C) -> (N, T, C)
        k2_log_probs = torch.nn.functional.log_softmax(act, dim=-1)

        supervision_segments = torch.tensor([[0, 0, 3]],
                                            dtype=torch.int32)
        dense_fsa_vec = k2.DenseFsaVec(k2_log_probs,
                                       supervision_segments).to(device)

        ctc_topo_inv = k2.arc_sort(
            build_ctc_topo([0, 1, 2, 3, 4]).invert_())
        linear_fsa = k2.linear_fsa([3, 3])
        decoding_graph = k2.intersect(ctc_topo_inv, linear_fsa)
        decoding_graph = k2.connect(decoding_graph).invert_().to(device)

        target_graph = k2.intersect_dense(decoding_graph, dense_fsa_vec,
                                          100.0)

        k2_scores = target_graph.get_tot_scores(log_semiring=True,
                                                use_double_scores=False)
        assert torch.allclose(torch_loss, -1 * k2_scores)
        assert torch.allclose(torch_loss,
                              torch.tensor([7.355742931366]).to(device))

        torch_loss.backward()
        (-k2_scores).backward()
        assert torch.allclose(torch_activation.grad, k2_activation.grad)
def compile_LG(L: Fsa, G: Fsa, labels_disambig_id_start: int,
               aux_labels_disambig_id_start: int) -> Fsa:
    """
    Creates a decoding graph using a lexicon fst ``L`` and language model
    fsa ``G``. Involves arc sorting, intersection, determinization, removal
    of disambiguation symbols and adding epsilon self-loops.

    Args:
        L:
            An ``Fsa`` that represents the lexicon (L), i.e. has phones as
            ``symbols`` and words as ``aux_symbols``.
        G:
            An ``Fsa`` that represents the language model (G), i.e. it's an
            acceptor with words as ``symbols``.
        labels_disambig_id_start:
            An integer ID corresponding to the first disambiguation symbol
            in the phonetic alphabet.
        aux_labels_disambig_id_start:
            An integer ID corresponding to the first disambiguation symbol
            in the words vocabulary.

    :return: The decoding graph ``LG``.
    """
    L_inv = k2.arc_sort(L.invert_())
    G = k2.arc_sort(G)

    logging.debug("Intersecting L and G")
    LG = k2.intersect(L_inv, G)
    logging.debug(f'LG shape = {LG.shape}')

    logging.debug("Connecting L*G")
    LG = k2.connect(LG).invert_()
    logging.debug(f'LG shape = {LG.shape}')

    logging.debug("Determinizing L*G")
    LG = k2.determinize(LG)
    logging.debug(f'LG shape = {LG.shape}')

    logging.debug("Connecting det(L*G)")
    LG = k2.connect(LG)
    logging.debug(f'LG shape = {LG.shape}')

    logging.debug("Removing disambiguation symbols on L*G")
    LG.labels[LG.labels >= labels_disambig_id_start] = 0
    LG.aux_labels[LG.aux_labels >= aux_labels_disambig_id_start] = 0

    LG = k2.add_epsilon_self_loops(LG)
    LG = k2.arc_sort(LG)
    logging.debug(
        f'LG is arc sorted: {(LG.properties & k2.fsa_properties.ARC_SORTED) != 0}'
    )
    return LG
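# A hedged usage sketch (not from the original source) for compile_LG()
# above. The file names follow the `data/lang_nosp` layout used elsewhere in
# this code; the disambiguation-symbol start IDs below are placeholders and
# depend on the actual phones.txt / words.txt.
def _demo_compile_LG():
    with open('data/lang_nosp/L.fst.txt') as f:
        L = k2.Fsa.from_openfst(f.read(), acceptor=False)
    with open('data/lang_nosp/G.fsa.txt') as f:
        G = k2.Fsa.from_openfst(f.read(), acceptor=True)
    LG = compile_LG(L, G,
                    labels_disambig_id_start=1000,        # placeholder
                    aux_labels_disambig_id_start=100000)  # placeholder
    return LG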
def compile_one_and_cache(self, text: str) -> k2.Fsa:
    '''Convert a transcript to an Fsa with the help of the lexicon
    and the word symbol table.

    Args:
      text:
        The transcript containing words separated by spaces.

    Returns:
      Return an FST corresponding to the transcript. Its `labels` are
      phone IDs and `aux_labels` are word IDs.
    '''
    tokens = (token if token in self.words else self.oov
              for token in text.split(' '))
    word_ids = [self.words[token] for token in tokens]
    fsa = k2.linear_fsa(word_ids)
    num_graph = k2.connect(k2.intersect(fsa, self.L_inv)).invert_()
    num_graph = k2.arc_sort(num_graph)
    return num_graph
def test_composition_equivalence(self):
    index = _generate_fsa_vec()
    index = k2.arc_sort(k2.connect(k2.remove_epsilon(index)))

    src = _generate_fsa_vec()

    replace = k2.replace_fsa(src, index, 1)
    replace = k2.top_sort(replace)

    f_fsa = _construct_f(src)
    f_fsa = k2.arc_sort(f_fsa)
    intersect = k2.intersect(index, f_fsa, treat_epsilons_specially=True)
    intersect = k2.invert(intersect)
    intersect = k2.top_sort(intersect)
    delattr(intersect, 'aux_labels')

    assert k2.is_rand_equivalent(replace,
                                 intersect,
                                 log_semiring=True,
                                 delta=1e-3)
def create_decoding_graph(texts, L, symbols):
    fsas = []
    for text in texts:
        filter_text = [
            i if i in symbols._sym2id else '<UNK>' for i in text.split(' ')
        ]
        word_ids = [symbols.get(i) for i in filter_text]
        fsa = k2.linear_fsa(word_ids)
        print("linear fsa is ", fsa)
        fsa = k2.arc_sort(fsa)
        print("linear fsa, arc-sorted, is ", fsa)
        print("begin")
        print(k2.is_arc_sorted(k2.get_properties(fsa)))
        decoding_graph = k2.intersect(fsa, L).invert_()
        print("linear fsa, composed, is ", fsa)
        print("decoding graph is ", decoding_graph)
        decoding_graph = k2.add_epsilon_self_loops(decoding_graph)
        print("decoding graph with self-loops is ", decoding_graph)
        fsas.append(decoding_graph)
    return k2.create_fsa_vec(fsas)
def get_hierarchical_targets(ys: List[List[int]],
                             lexicon: k2.Fsa) -> List[Tensor]:
    """Get hierarchical transcripts (i.e., phone level transcripts) from
    transcripts (i.e., word level transcripts).

    Args:
        ys: Word level transcripts.
        lexicon: Its labels are words, while its aux_labels are phones.

    Returns:
        List[Tensor]: Phone level transcripts.
    """
    if lexicon is None:
        return ys
    else:
        L_inv = lexicon

    n_batch = len(ys)
    indices = torch.tensor(range(n_batch))
    device = L_inv.device

    transcripts = k2.create_fsa_vec(
        [k2.linear_fsa(x, device=device) for x in ys])
    transcripts_with_self_loops = k2.add_epsilon_self_loops(transcripts)

    transcripts_lexicon = k2.intersect(L_inv,
                                       transcripts_with_self_loops,
                                       treat_epsilons_specially=False)
    # Don't call invert_() above because we want to return phone IDs,
    # which is the `aux_labels` of transcripts_lexicon
    transcripts_lexicon = k2.remove_epsilon(transcripts_lexicon)
    transcripts_lexicon = k2.top_sort(transcripts_lexicon)

    transcripts_lexicon = k2.shortest_path(transcripts_lexicon,
                                           use_double_scores=True)

    ys = get_texts(transcripts_lexicon, indices)
    ys = [torch.tensor(y) for y in ys]

    return ys
def compile_LG(L: Fsa, G: Fsa, ctc_topo_inv: Fsa,
               labels_disambig_id_start: int,
               aux_labels_disambig_id_start: int) -> Fsa:
    """
    Creates a decoding graph using a lexicon fst ``L`` and language model
    fsa ``G``. Involves arc sorting, intersection, determinization, removal
    of disambiguation symbols and adding epsilon self-loops.

    Args:
        L:
            An ``Fsa`` that represents the lexicon (L), i.e. has phones as
            ``symbols`` and words as ``aux_symbols``.
        G:
            An ``Fsa`` that represents the language model (G), i.e. it's an
            acceptor with words as ``symbols``.
        ctc_topo_inv:
            Epsilons are in ``aux_labels`` and ``labels`` contain phone IDs.
        labels_disambig_id_start:
            An integer ID corresponding to the first disambiguation symbol
            in the phonetic alphabet.
        aux_labels_disambig_id_start:
            An integer ID corresponding to the first disambiguation symbol
            in the words vocabulary.

    :return: The decoding graph, i.e. ``ctc_topo_inv`` composed with ``LG``.
    """
    L_inv = k2.arc_sort(L.invert_())
    G = k2.arc_sort(G)

    logging.debug("Intersecting L and G")
    LG = k2.intersect(L_inv, G)
    logging.debug(f'LG shape = {LG.shape}')

    logging.debug("Connecting L*G")
    LG = k2.connect(LG).invert_()
    logging.debug(f'LG shape = {LG.shape}')

    logging.debug("Determinizing L*G")
    LG = k2.determinize(LG)
    logging.debug(f'LG shape = {LG.shape}')

    logging.debug("Connecting det(L*G)")
    LG = k2.connect(LG)
    logging.debug(f'LG shape = {LG.shape}')

    logging.debug("Removing disambiguation symbols on L*G")
    LG.labels[LG.labels >= labels_disambig_id_start] = 0
    if isinstance(LG.aux_labels, torch.Tensor):
        LG.aux_labels[LG.aux_labels >= aux_labels_disambig_id_start] = 0
    else:
        LG.aux_labels.values()[
            LG.aux_labels.values() >= aux_labels_disambig_id_start] = 0

    logging.debug("Removing epsilons")
    LG = k2.remove_epsilons_iterative_tropical(LG)
    logging.debug(f'LG shape = {LG.shape}')

    logging.debug("Connecting rm-eps(det(L*G))")
    LG = k2.connect(LG)
    logging.debug(f'LG shape = {LG.shape}')
    LG.aux_labels = k2.ragged.remove_values_eq(LG.aux_labels, 0)

    logging.debug("Arc sorting")
    LG = k2.arc_sort(LG)

    logging.debug("Composing")
    LG = k2.compose(ctc_topo_inv, LG)

    logging.debug("Connecting")
    LG = k2.connect(LG)

    logging.debug("Arc sorting")
    LG = k2.arc_sort(LG)
    logging.debug(
        f'LG is arc sorted: {(LG.properties & k2.fsa_properties.ARC_SORTED) != 0}'
    )
    return LG
def visualize_ctc_topo():
    '''This function shows how to visualize
    standard/modified ctc topologies. It's for
    demonstration only, not for testing.
    '''
    max_token = 2
    labels_sym = k2.SymbolTable.from_str('''
        <blk> 0
        z 1
        o 2
    ''')
    aux_labels_sym = k2.SymbolTable.from_str('''
        z 1
        o 2
    ''')
    word_sym = k2.SymbolTable.from_str('''
        zoo 1
    ''')

    standard = k2.ctc_topo(max_token, modified=False)
    modified = k2.ctc_topo(max_token, modified=True)

    standard.labels_sym = labels_sym
    standard.aux_labels_sym = aux_labels_sym

    modified.labels_sym = labels_sym
    modified.aux_labels_sym = aux_labels_sym

    standard.draw('standard_topo.svg', title='standard CTC topo')
    modified.draw('modified_topo.svg', title='modified CTC topo')

    fsa = k2.linear_fst([1, 2, 2], [1, 0, 0])
    fsa.labels_sym = labels_sym
    fsa.aux_labels_sym = word_sym
    fsa.draw('transcript.svg', title='transcript')

    standard_graph = k2.compose(standard, fsa)
    modified_graph = k2.compose(modified, fsa)
    standard_graph.draw('standard_graph.svg', title='standard graph')
    modified_graph.draw('modified_graph.svg', title='modified graph')

    # z z <blk> <blk> o o <blk> o <blk>
    inputs = k2.linear_fsa([1, 1, 0, 0, 2, 2, 0, 2, 0])
    inputs.labels_sym = labels_sym
    inputs.draw('inputs.svg', title='inputs')

    standard_lattice = k2.intersect(standard_graph,
                                    inputs,
                                    treat_epsilons_specially=False)
    standard_lattice.draw('standard_lattice.svg', title='standard lattice')

    modified_lattice = k2.intersect(modified_graph,
                                    inputs,
                                    treat_epsilons_specially=False)
    modified_lattice = k2.connect(modified_lattice)
    modified_lattice.draw('modified_lattice.svg', title='modified lattice')

    # z z <blk> <blk> o o o <blk>
    inputs2 = k2.linear_fsa([1, 1, 0, 0, 2, 2, 2, 0])
    inputs2.labels_sym = labels_sym
    inputs2.draw('inputs2.svg', title='inputs2')

    standard_lattice2 = k2.intersect(standard_graph,
                                     inputs2,
                                     treat_epsilons_specially=False)
    standard_lattice2 = k2.connect(standard_lattice2)
    # It's empty since the topo requires that there must be a blank
    # between the two o's in zoo
    assert standard_lattice2.num_arcs == 0
    standard_lattice2.draw('standard_lattice2.svg',
                           title='standard lattice2')

    modified_lattice2 = k2.intersect(modified_graph,
                                     inputs2,
                                     treat_epsilons_specially=False)
    modified_lattice2 = k2.connect(modified_lattice2)
    modified_lattice2.draw('modified_lattice2.svg',
                           title='modified lattice2')
def test_case4(self):
    for device in self.devices:
        # put case3, case2 and case1 into a batch
        torch_activation_1 = torch.tensor(
            [[0., 0., 0., 0., 0.]]).to(device).requires_grad_(True)

        torch_activation_2 = torch.arange(1, 16).reshape(3, 5).to(
            torch.float32).to(device).requires_grad_(True)

        torch_activation_3 = torch.tensor([
            [-5, -4, -3, -2, -1],
            [-10, -9, -8, -7, -6],
            [-15, -14, -13, -12, -11.],
        ]).to(device).requires_grad_(True)

        k2_activation_1 = torch_activation_1.detach().clone(
        ).requires_grad_(True)
        k2_activation_2 = torch_activation_2.detach().clone(
        ).requires_grad_(True)
        k2_activation_3 = torch_activation_3.detach().clone(
        ).requires_grad_(True)

        # [T, N, C]
        torch_activations = torch.nn.utils.rnn.pad_sequence(
            [torch_activation_3, torch_activation_2, torch_activation_1],
            batch_first=False,
            padding_value=0)

        # [N, T, C]
        k2_activations = torch.nn.utils.rnn.pad_sequence(
            [k2_activation_3, k2_activation_2, k2_activation_1],
            batch_first=True,
            padding_value=0)

        # [[b,c], [c,c], [a]]
        targets = torch.tensor([2, 3, 3, 3, 1]).to(device)
        input_lengths = torch.tensor([3, 3, 1]).to(device)
        target_lengths = torch.tensor([2, 2, 1]).to(device)

        torch_log_probs = torch.nn.functional.log_softmax(
            torch_activations, dim=-1)  # (T, N, C)

        torch_loss = torch.nn.functional.ctc_loss(
            log_probs=torch_log_probs,
            targets=targets,
            input_lengths=input_lengths,
            target_lengths=target_lengths,
            reduction='none')

        assert torch.allclose(
            torch_loss,
            torch.tensor([4.938850402832, 7.355742931366,
                          1.6094379425049]).to(device))

        k2_log_probs = torch.nn.functional.log_softmax(k2_activations,
                                                       dim=-1)
        supervision_segments = torch.tensor(
            [[0, 0, 3], [1, 0, 3], [2, 0, 1]], dtype=torch.int32)
        dense_fsa_vec = k2.DenseFsaVec(k2_log_probs,
                                       supervision_segments).to(device)

        ctc_topo_inv = k2.arc_sort(
            build_ctc_topo([0, 1, 2, 3, 4]).invert_())
        # [[b, c], [c, c], [a]]
        linear_fsa = k2.linear_fsa([[2, 3], [3, 3], [1]])
        decoding_graph = k2.intersect(ctc_topo_inv, linear_fsa)
        decoding_graph = k2.connect(decoding_graph).invert_().to(device)

        target_graph = k2.intersect_dense(decoding_graph, dense_fsa_vec,
                                          100.0)

        k2_scores = target_graph.get_tot_scores(log_semiring=True,
                                                use_double_scores=False)
        assert torch.allclose(torch_loss, -1 * k2_scores)

        scale = torch.tensor([1., -2, 3.5]).to(device)
        (torch_loss * scale).sum().backward()
        (-k2_scores * scale).sum().backward()
        assert torch.allclose(torch_activation_1.grad,
                              k2_activation_1.grad)
        assert torch.allclose(torch_activation_2.grad,
                              k2_activation_2.grad)
        assert torch.allclose(torch_activation_3.grad,
                              k2_activation_3.grad)
def main():
    # load L, G, symbol_table
    lang_dir = 'data/lang_nosp'
    with open(lang_dir + '/L.fst.txt') as f:
        L = k2.Fsa.from_openfst(f.read(), acceptor=False)
    with open(lang_dir + '/G.fsa.txt') as f:
        G = k2.Fsa.from_openfst(f.read(), acceptor=True)
    with open(lang_dir + '/words.txt') as f:
        symbol_table = k2.SymbolTable.from_str(f.read())

    L = k2.arc_sort(L.invert_())
    G = k2.arc_sort(G)
    graph = k2.intersect(L, G)
    graph = k2.arc_sort(graph)

    # load dataset
    feature_dir = 'exp/data1'
    cuts_train = CutSet.from_json(feature_dir +
                                  '/cuts_train-clean-100.json.gz')
    cuts_dev = CutSet.from_json(feature_dir + '/cuts_dev-clean.json.gz')

    train = K2SpeechRecognitionIterableDataset(cuts_train, shuffle=True)
    validate = K2SpeechRecognitionIterableDataset(cuts_dev, shuffle=False)
    train_dl = torch.utils.data.DataLoader(train,
                                           batch_size=None,
                                           num_workers=1)
    valid_dl = torch.utils.data.DataLoader(validate,
                                           batch_size=None,
                                           num_workers=1)

    dir = 'exp'
    setup_logger('{}/log/log-train'.format(dir))

    if not torch.cuda.is_available():
        logging.error('No GPU detected!')
        sys.exit(-1)
    device_id = 0
    device = torch.device('cuda', device_id)
    model = Wav2Letter(num_classes=364, input_type='mfcc', num_features=40)
    model.to(device)

    learning_rate = 0.001
    start_epoch = 0
    num_epochs = 10
    best_objf = 100000
    best_epoch = start_epoch
    best_model_path = os.path.join(dir, 'best_model.pt')
    best_epoch_info_filename = os.path.join(dir, 'best-epoch-info')

    optimizer = optim.Adam(model.parameters(),
                           lr=learning_rate,
                           weight_decay=5e-4)
    # optimizer = optim.SGD(net.parameters(), lr=learning_rate, momentum=0.9)

    for epoch in range(start_epoch, num_epochs):
        curr_learning_rate = learning_rate * pow(0.4, epoch)
        for param_group in optimizer.param_groups:
            param_group['lr'] = curr_learning_rate

        logging.info('epoch {}, learning rate {}'.format(
            epoch, curr_learning_rate))
        objf = train_one_epoch(dataloader=train_dl,
                               valid_dataloader=valid_dl,
                               model=model,
                               device=device,
                               graph=graph,
                               symbols=symbol_table,
                               optimizer=optimizer,
                               current_epoch=epoch,
                               num_epochs=num_epochs)

        if objf < best_objf:
            best_objf = objf
            best_epoch = epoch
            save_checkpoint(filename=best_model_path,
                            model=model,
                            epoch=epoch,
                            learning_rate=curr_learning_rate,
                            objf=objf)
            save_training_info(filename=best_epoch_info_filename,
                               model_path=best_model_path,
                               current_epoch=epoch,
                               learning_rate=curr_learning_rate,
                               objf=best_objf,
                               best_objf=best_objf,
                               best_epoch=best_epoch)

        # we always save the model for every epoch
        model_path = os.path.join(dir, 'epoch-{}.pt'.format(epoch))
        save_checkpoint(filename=model_path,
                        model=model,
                        epoch=epoch,
                        learning_rate=curr_learning_rate,
                        objf=objf)
        epoch_info_filename = os.path.join(dir, 'epoch-{}-info'.format(epoch))
        save_training_info(filename=epoch_info_filename,
                           model_path=model_path,
                           current_epoch=epoch,
                           learning_rate=curr_learning_rate,
                           objf=objf,
                           best_objf=best_objf,
                           best_epoch=best_epoch)

    logging.warning('Done')
def test_random_case2(self):
    # 2 sequences
    for device in self.devices:
        T1 = torch.randint(10, 200, (1,)).item()
        T2 = torch.randint(9, 100, (1,)).item()
        C = torch.randint(20, 30, (1,)).item()
        if T1 < T2:
            T1, T2 = T2, T1

        torch_activation_1 = torch.rand((T1, C),
                                        dtype=torch.float32,
                                        device=device).requires_grad_(True)
        torch_activation_2 = torch.rand((T2, C),
                                        dtype=torch.float32,
                                        device=device).requires_grad_(True)

        k2_activation_1 = torch_activation_1.detach().clone(
        ).requires_grad_(True)
        k2_activation_2 = torch_activation_2.detach().clone(
        ).requires_grad_(True)

        # [T, N, C]
        torch_activations = torch.nn.utils.rnn.pad_sequence(
            [torch_activation_1, torch_activation_2],
            batch_first=False,
            padding_value=0)

        # [N, T, C]
        k2_activations = torch.nn.utils.rnn.pad_sequence(
            [k2_activation_1, k2_activation_2],
            batch_first=True,
            padding_value=0)

        target_length1 = torch.randint(1, T1, (1,)).item()
        target_length2 = torch.randint(1, T2, (1,)).item()

        target_lengths = torch.tensor([target_length1,
                                       target_length2]).to(device)
        targets = torch.randint(1, C - 1,
                                (target_lengths.sum(),)).to(device)

        # [T, N, C]
        torch_log_probs = torch.nn.functional.log_softmax(
            torch_activations, dim=-1)
        input_lengths = torch.tensor([T1, T2]).to(device)

        torch_loss = torch.nn.functional.ctc_loss(
            log_probs=torch_log_probs,
            targets=targets,
            input_lengths=input_lengths,
            target_lengths=target_lengths,
            reduction='none')

        assert T1 >= T2
        supervision_segments = torch.tensor([[0, 0, T1], [1, 0, T2]],
                                            dtype=torch.int32)
        k2_log_probs = torch.nn.functional.log_softmax(k2_activations,
                                                       dim=-1)
        dense_fsa_vec = k2.DenseFsaVec(k2_log_probs,
                                       supervision_segments).to(device)

        ctc_topo_inv = k2.arc_sort(
            build_ctc_topo(list(range(C))).invert_())
        linear_fsa = k2.linear_fsa([
            targets[:target_length1].tolist(),
            targets[target_length1:].tolist()
        ])
        decoding_graph = k2.intersect(ctc_topo_inv, linear_fsa)
        decoding_graph = k2.connect(decoding_graph).invert_().to(device)

        target_graph = k2.intersect_dense(decoding_graph, dense_fsa_vec,
                                          100.0)

        k2_scores = target_graph.get_tot_scores(log_semiring=True,
                                                use_double_scores=False)
        assert torch.allclose(torch_loss, -1 * k2_scores)

        scale = torch.rand_like(torch_loss) * 100
        (torch_loss * scale).sum().backward()
        (-k2_scores * scale).sum().backward()
        assert torch.allclose(torch_activation_1.grad,
                              k2_activation_1.grad,
                              atol=1e-2)
        assert torch.allclose(torch_activation_2.grad,
                              k2_activation_2.grad,
                              atol=1e-2)