def __init__(self,
             L_inv: k2.Fsa,
             phones: k2.SymbolTable,
             words: k2.SymbolTable,
             oov: str = '<UNK>'):
    '''
    Args:
      L_inv:
        Its labels are words, while its aux_labels are phones.
      phones:
        The phone symbol table.
      words:
        The word symbol table.
      oov:
        Out of vocabulary word.
    '''
    # Arc-sort L_inv if it is not already arc-sorted. Note that the
    # parentheses matter: `&` binds less tightly than `==`.
    if (L_inv.properties & k2.fsa_properties.ARC_SORTED) == 0:
        L_inv = k2.arc_sort(L_inv)

    assert oov in words

    self.L_inv = L_inv
    self.phones = phones
    self.words = words
    self.oov = oov

    phone_ids = get_phone_symbols(phones)
    phone_ids_with_blank = [0] + phone_ids
    self.ctc_topo = k2.arc_sort(build_ctc_topo(phone_ids_with_blank))

def __init__(self,
             L_inv: k2.Fsa,
             phones: k2.SymbolTable,
             words: k2.SymbolTable,
             oov: str = '<UNK>'):
    '''
    Args:
      L_inv:
        Its labels are words, while its aux_labels are phones.
      phones:
        The phone symbol table.
      words:
        The word symbol table.
      oov:
        Out of vocabulary word.
    '''
    # Arc-sort L_inv if it is not already arc-sorted.
    if (L_inv.properties & k2.fsa_properties.ARC_SORTED) == 0:
        L_inv = k2.arc_sort(L_inv)

    assert oov in words

    self.L_inv = L_inv
    self.phones = phones
    self.words = words
    self.oov = oov

    ctc_topo = build_ctc_topo(list(phones._id2sym.keys()))
    self.ctc_topo = k2.arc_sort(ctc_topo)

def intersect_with_self_loops(base_graph: 'k2.Fsa',
                              aux_graph: 'k2.Fsa') -> 'k2.Fsa':
    """Intersection helper function.
    """
    assert hasattr(base_graph, "aux_labels")
    assert not hasattr(aux_graph, "aux_labels")

    aux_graph_with_self_loops = k2.arc_sort(
        k2.add_epsilon_self_loops(aux_graph)).to(base_graph.device)

    result = k2.intersect(
        k2.arc_sort(base_graph),
        aux_graph_with_self_loops,
        treat_epsilons_specially=False,
    )
    setattr(result, "phones", result.labels)
    return result

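# A minimal usage sketch of intersect_with_self_loops (the toy FSAs below are
# written for illustration and are not from the original source): `base` is a
# small transducer and `aux` is a plain acceptor, as the asserts above require.
base = k2.Fsa.from_str('''
    0 1 1 10 0.0
    1 2 -1 -1 0.0
    2
''', num_aux_labels=1)
aux = k2.Fsa.from_str('''
    0 1 1 0.0
    1 2 -1 0.0
    2
''')
result = intersect_with_self_loops(base, aux)
# The helper attaches the result's labels as a `phones` attribute.
print(result.phones)
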
def compile(self,
            texts: Iterable[str],
            P: k2.Fsa,
            replicate_den: bool = True) -> Tuple[k2.Fsa, k2.Fsa]:
    '''Create numerator and denominator graphs from transcripts
    and the bigram phone LM.

    Args:
      texts:
        A list of transcripts. Within a transcript, words are
        separated by spaces.
      P:
        The bigram phone LM created by :func:`create_bigram_phone_lm`.
      replicate_den:
        If True, the returned den_graph is replicated to match the number
        of FSAs in the returned num_graph; if False, the returned den_graph
        contains only a single FSA.
    Returns:
      A tuple (num_graph, den_graph), where

        - `num_graph` is the numerator graph. It is an FsaVec with
          shape `(len(texts), None, None)`.

        - `den_graph` is the denominator graph. It is an FsaVec with the same
          shape as `num_graph` if replicate_den is True; otherwise, it is an
          FsaVec containing only a single FSA.
    '''
    assert P.device == self.device
    P_with_self_loops = k2.add_epsilon_self_loops(P)

    ctc_topo_P = k2.intersect(self.ctc_topo_inv,
                              P_with_self_loops,
                              treat_epsilons_specially=False).invert()
    ctc_topo_P = k2.arc_sort(ctc_topo_P)

    num_graphs = self.build_num_graphs(texts)
    num_graphs_with_self_loops = k2.remove_epsilon_and_add_self_loops(
        num_graphs)
    num_graphs_with_self_loops = k2.arc_sort(num_graphs_with_self_loops)

    num = k2.compose(ctc_topo_P,
                     num_graphs_with_self_loops,
                     treat_epsilons_specially=False)
    num = k2.arc_sort(num)

    ctc_topo_P_vec = k2.create_fsa_vec([ctc_topo_P.detach()])
    if replicate_den:
        indexes = torch.zeros(len(texts),
                              dtype=torch.int32,
                              device=self.device)
        den = k2.index_fsa(ctc_topo_P_vec, indexes)
    else:
        den = ctc_topo_P_vec

    return num, den

def test_empty_fsa(self):
    array_size = k2.IntArray2Size(0, 0)
    fsa = k2.Fsa.create_fsa_with_size(array_size)
    arc_map = k2.IntArray1.create_array_with_size(fsa.size2)
    k2.arc_sort(fsa, arc_map)
    self.assertTrue(k2.is_empty(fsa))
    self.assertTrue(arc_map.empty())

    # test without arc_map
    k2.arc_sort(fsa)
    self.assertTrue(k2.is_empty(fsa))

def compile(self, texts: Iterable[str],
            P: k2.Fsa) -> Tuple[k2.Fsa, k2.Fsa, k2.Fsa]:
    '''Create numerator and denominator graphs from transcripts
    and the bigram phone LM.

    Args:
      texts:
        A list of transcripts. Within a transcript, words are
        separated by spaces.
      P:
        The bigram phone LM created by :func:`create_bigram_phone_lm`.
    Returns:
      A tuple (num_graph, den_graph, decoding_graph), where

        - `num_graph` is the numerator graph. It is an FsaVec with
          shape `(len(texts), None, None)`.
          It is the result of compose(ctc_topo, P, L, transcript).

        - `den_graph` is the denominator graph. It is an FsaVec with the same
          shape as `num_graph`. It is the result of compose(ctc_topo, P).

        - `decoding_graph` is the result of compose(ctc_topo, L_disambig, G).
          Note that it is a single Fsa, not an FsaVec.
    '''
    assert P.device == self.device
    P_with_self_loops = k2.add_epsilon_self_loops(P)

    ctc_topo_P = k2.intersect(self.ctc_topo_inv,
                              P_with_self_loops,
                              treat_epsilons_specially=False).invert()
    ctc_topo_P = k2.arc_sort(ctc_topo_P)

    num_graphs = self.build_num_graphs(texts)
    num_graphs_with_self_loops = k2.remove_epsilon_and_add_self_loops(
        num_graphs)
    num_graphs_with_self_loops = k2.arc_sort(num_graphs_with_self_loops)

    num = k2.compose(ctc_topo_P,
                     num_graphs_with_self_loops,
                     treat_epsilons_specially=False,
                     inner_labels='phones')
    num = k2.arc_sort(num)

    ctc_topo_P_vec = k2.create_fsa_vec([ctc_topo_P.detach()])
    indexes = torch.zeros(len(texts),
                          dtype=torch.int32,
                          device=self.device)
    den = k2.index_fsa(ctc_topo_P_vec, indexes)

    return num, den, self.decoding_graph

def test_treat_epsilon_specially_true(self):
    # this version works only on CPU and requires
    # arc-sorted inputs
    # a_fsa recognizes `(1|3)?2*`
    s1 = '''
        0 1 3 0.0
        0 1 1 0.2
        0 1 0 0.1
        1 1 2 0.3
        1 2 -1 0.4
        2
    '''
    a_fsa = k2.Fsa.from_str(s1)
    a_fsa.requires_grad_(True)

    # b_fsa recognizes `1|2|5`
    s2 = '''
        0 1 5 0
        0 1 1 1
        0 1 2 2
        1 2 -1 3
        2
    '''
    b_fsa = k2.Fsa.from_str(s2)
    b_fsa.requires_grad_(True)

    # fsa recognizes 1|2
    fsa = k2.intersect(k2.arc_sort(a_fsa), k2.arc_sort(b_fsa))
    assert len(fsa.shape) == 2

    actual_str = k2.to_str_simple(fsa)
    expected_str = '\n'.join(
        ['0 1 0 0.1', '0 2 1 1.2', '1 2 2 2.3', '2 3 -1 3.4', '3'])
    assert actual_str.strip() == expected_str

    loss = fsa.scores.sum()
    (-loss).backward()
    # arcs 1, 2, 3, and 4 of a_fsa are kept in the final intersected FSA
    assert torch.allclose(a_fsa.grad,
                          torch.tensor([0, -1, -1, -1, -1]).to(a_fsa.grad))

    # arcs 1, 2, and 3 of b_fsa are kept in the final intersected FSA
    assert torch.allclose(b_fsa.grad,
                          torch.tensor([0, -1, -1, -1]).to(b_fsa.grad))

    # if any of the input FSAs is an FsaVec,
    # the output FSA is also an FsaVec.
    a_fsa.scores.grad = None
    b_fsa.scores.grad = None
    a_fsa = k2.create_fsa_vec([a_fsa])
    fsa = k2.intersect(k2.arc_sort(a_fsa), k2.arc_sort(b_fsa))
    assert len(fsa.shape) == 3

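# A small standalone sketch (toy FSA, not from the original tests): as the test
# above notes, k2.intersect with treat_epsilons_specially=True requires
# arc-sorted inputs; k2.arc_sort sorts the outgoing arcs of every state by label.
unsorted_fsa = k2.Fsa.from_str('''
    0 1 3 0.5
    0 1 1 0.5
    0 1 2 0.5
    1 2 -1 0.0
    2
''')
sorted_fsa = k2.arc_sort(unsorted_fsa)
print(k2.to_str_simple(sorted_fsa))
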
def __init__(self,
             lexicon: Lexicon,
             P: k2.Fsa,
             device: torch.device,
             oov: str = '<UNK>'):
    '''
    Args:
      lexicon:
        It contains L_inv (whose labels are words and aux_labels are phones)
        as well as the phone and word symbol tables.
      P:
        A phone bigram LM if the pronunciations in the lexicon are in phones;
        a word piece bigram if the pronunciations in the lexicon are word
        pieces.
      device:
        The device on which the graphs are built.
      oov:
        Out of vocabulary word.
    '''
    self.lexicon = lexicon
    L_inv = self.lexicon.L_inv.to(device)
    P = P.to(device)

    # Arc-sort L_inv if it is not already arc-sorted.
    if (L_inv.properties & k2.fsa_properties.ARC_SORTED) == 0:
        L_inv = k2.arc_sort(L_inv)

    assert L_inv.requires_grad is False
    assert oov in self.lexicon.words

    self.L_inv = L_inv
    self.oov_id = self.lexicon.words[oov]
    self.oov = oov
    self.device = device

    phone_symbols = get_phone_symbols(self.lexicon.phones)
    phone_symbols_with_blank = [0] + phone_symbols

    ctc_topo = build_ctc_topo(phone_symbols_with_blank).to(device)
    assert ctc_topo.requires_grad is False
    ctc_topo_inv = k2.arc_sort(ctc_topo.invert_())

    P_with_self_loops = k2.add_epsilon_self_loops(P)
    ctc_topo_P = k2.intersect(ctc_topo_inv,
                              P_with_self_loops,
                              treat_epsilons_specially=False).invert()

    self.ctc_topo_P = k2.arc_sort(ctc_topo_P)

def __init__(
    self,
    num_classes: int,
    topo_type: str = "default",
    topo_with_self_loops: bool = True,
    device: torch.device = torch.device("cpu"),
):
    # use k2 import guard
    k2_import_guard()

    self.topo_type = topo_type
    self.device = device
    self.base_graph = k2.arc_sort(
        build_topo(topo_type, list(range(num_classes)),
                   topo_with_self_loops)).to(self.device)
    self.ctc_topo_inv = k2.arc_sort(self.base_graph.invert())

def build_ctc_topo2(phones: List[int]):
    # See https://github.com/k2-fsa/k2/issues/746#issuecomment-856421616
    assert 0 in phones, 'We assume 0 is the ID of the blank symbol'

    phones = phones.copy()
    phones.remove(0)
    num_phones = len(phones)

    start = 0
    final = num_phones + 1

    arcs = []
    arcs.append([start, start, 0, 0, 0])
    arcs.append([start, final, -1, -1, 0])
    arcs.append([final])

    for i, p in enumerate(phones):
        i += 1
        arcs.append([start, start, p, p, 0])
        arcs.append([start, i, p, p, 0])

        arcs.append([i, i, p, 0, 0])
        arcs.append([i, start, p, 0, 0])

    arcs = sorted(arcs, key=lambda arc: arc[0])
    arcs = [[str(i) for i in arc] for arc in arcs]
    arcs = [' '.join(arc) for arc in arcs]
    arcs = '\n'.join(arcs)

    ctc_topo = k2.Fsa.from_str(arcs, False)
    return k2.arc_sort(ctc_topo)

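# A minimal usage sketch for build_ctc_topo2 (the toy phone set is for
# illustration only, not from the original source): 0 is the blank and
# 1, 2 are two real phones.
topo2 = build_ctc_topo2([0, 1, 2])
print(topo2.shape)
print(topo2)
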
def get_hierarchical_targets(ys: List[List[int]],
                             lexicon: k2.Fsa) -> List[Tensor]:
    """Get hierarchical transcripts (i.e., phone level transcripts) from
    transcripts (i.e., word level transcripts).

    Args:
        ys: Word level transcripts.
        lexicon: Its labels are words, while its aux_labels are phones.

    Returns:
        List[Tensor]: Phone level transcripts.
    """
    if lexicon is None:
        return ys
    else:
        L_inv = lexicon

    n_batch = len(ys)
    indices = torch.tensor(range(n_batch))

    transcripts = k2.create_fsa_vec([k2.linear_fsa(x) for x in ys])
    transcripts_lexicon = k2.intersect(transcripts, L_inv)
    transcripts_lexicon = k2.arc_sort(k2.connect(transcripts_lexicon))
    transcripts_lexicon = k2.remove_epsilon(transcripts_lexicon)
    transcripts_lexicon = k2.shortest_path(transcripts_lexicon,
                                           use_double_scores=True)

    ys = get_texts(transcripts_lexicon, indices)
    ys = [torch.tensor(y) for y in ys]

    return ys

def test_case1(self):
    devices = [torch.device('cpu')]
    if torch.cuda.is_available():
        devices.append(torch.device('cuda'))

    for device in devices:
        # suppose we have five symbols: <blk>, a, b, c, d
        torch_activation = torch.tensor([0.2, 0.2, 0.2, 0.2,
                                         0.2]).to(device)
        k2_activation = torch_activation.detach().clone()

        # (T, N, C)
        torch_activation = torch_activation.reshape(
            1, 1, -1).requires_grad_(True)

        # (N, T, C)
        k2_activation = k2_activation.reshape(1, 1,
                                              -1).requires_grad_(True)

        torch_log_probs = torch.nn.functional.log_softmax(
            torch_activation, dim=-1)  # (T, N, C)

        # we have only one sequence and its label is `a`
        targets = torch.tensor([1]).to(device)
        input_lengths = torch.tensor([1]).to(device)
        target_lengths = torch.tensor([1]).to(device)

        torch_loss = torch.nn.functional.ctc_loss(
            log_probs=torch_log_probs,
            targets=targets,
            input_lengths=input_lengths,
            target_lengths=target_lengths,
            reduction='none')

        assert torch.allclose(torch_loss,
                              torch.tensor([1.6094379425049]).to(device))

        # (N, T, C)
        k2_log_probs = torch.nn.functional.log_softmax(k2_activation,
                                                       dim=-1)

        supervision_segments = torch.tensor([[0, 0, 1]], dtype=torch.int32)
        dense_fsa_vec = k2.DenseFsaVec(k2_log_probs,
                                       supervision_segments).to(device)

        ctc_topo_inv = k2.arc_sort(
            build_ctc_topo([0, 1, 2, 3, 4]).invert_())
        linear_fsa = k2.linear_fsa([1])
        decoding_graph = k2.intersect(ctc_topo_inv, linear_fsa)
        decoding_graph = k2.connect(decoding_graph).invert_().to(device)

        target_graph = k2.intersect_dense(decoding_graph, dense_fsa_vec,
                                          100.0)

        k2_scores = target_graph.get_tot_scores(log_semiring=True,
                                                use_double_scores=False)
        assert torch.allclose(torch_loss, -1 * k2_scores)

        torch_loss.backward()
        (-k2_scores).backward()
        assert torch.allclose(torch_activation.grad, k2_activation.grad)

def test1(self):
    s = '''
        0 4 1 1
        0 1 1 1
        1 2 2 2
        1 3 3 3
        2 7 1 4
        3 7 1 5
        4 6 1 2
        4 6 1 3
        4 5 1 3
        4 8 -1 2
        5 8 -1 4
        6 8 -1 3
        7 8 -1 5
        8
    '''
    fsa = k2.Fsa.from_str(s)
    prop = fsa.properties
    self.assertFalse(
        prop & k2.fsa_properties.ARC_SORTED_AND_DETERMINISTIC != 0)

    dest = k2.determinize(fsa)
    log_semiring = False
    self.assertTrue(k2.is_rand_equivalent(fsa, dest, log_semiring))

    arc_sorted = k2.arc_sort(dest)
    prop = arc_sorted.properties
    self.assertTrue(
        prop & k2.fsa_properties.ARC_SORTED_AND_DETERMINISTIC != 0)

def build_ctc_topo(tokens: List[int]) -> k2.Fsa:
    '''Build CTC topology.

    A token appearing once on the input side may be repeated on the output
    side; the extra repetitions carry epsilon input labels.

    Caution:
      The resulting topo is an FST. Epsilons are on the left side
      (i.e., ilabels) and tokens are on the right side (i.e., olabels).

    Args:
      tokens:
        A list of tokens, e.g., phones, characters, etc.
    Returns:
      Returns an FST that maps a token sequence to the same sequence with
      tokens possibly repeated.
    '''
    assert 0 in tokens, 'We assume 0 is the ID of the blank symbol'

    num_states = len(tokens)
    final_state = num_states

    rules = ''
    for i in range(num_states):
        for j in range(num_states):
            if i == j:
                rules += f'{i} {i} 0 {tokens[i]} 0.0\n'
            else:
                rules += f'{i} {j} {tokens[j]} {tokens[j]} 0.0\n'
        rules += f'{i} {final_state} -1 -1 0.0\n'
    rules += f'{final_state}'
    ans = k2.Fsa.from_str(rules)
    ans = k2.arc_sort(ans)
    return ans

def compile(self, texts: Iterable[str],
            P: k2.Fsa) -> Tuple[k2.Fsa, k2.Fsa]:
    '''Create numerator and denominator graphs from transcripts
    and the bigram phone LM.

    Args:
      texts:
        A list of transcripts. Within a transcript, words are
        separated by spaces.
      P:
        The bigram phone LM created by :func:`create_bigram_phone_lm`.
    Returns:
      A tuple (num_graph, den_graph), where

        - `num_graph` is the numerator graph. It is an FsaVec with
          shape `(len(texts), None, None)`.

        - `den_graph` is the denominator graph. It is an FsaVec with the same
          shape as `num_graph`.
    '''
    assert P.is_cpu()

    ctc_topo_P = k2.intersect(self.ctc_topo, P).invert_()
    ctc_topo_P = k2.connect(ctc_topo_P)

    num_graphs = k2.create_fsa_vec(
        [self.compile_one_and_cache(text) for text in texts])

    num = k2.compose(ctc_topo_P, num_graphs)
    num = k2.connect(num)
    num = k2.arc_sort(num)

    den = k2.create_fsa_vec([ctc_topo_P.detach()] * len(texts))

    return num, den

def build_ctc_topo(tokens: List[int]) -> k2.Fsa:
    """Build CTC topology.

    A token which appears once on the right side (i.e., olabels) may
    appear multiple times on the left side (i.e., ilabels), possibly with
    epsilons in between.
    When 0 appears on the left side, it represents the blank symbol;
    when it appears on the right side, it indicates an epsilon. That is,
    0 has two meanings here.

    Args:
      tokens:
        A list of tokens, e.g., phones, characters, etc.
    Returns:
      Returns an FST that converts repeated tokens to a single token.
    """
    assert 0 in tokens, "We assume 0 is the ID of the blank symbol"

    num_states = len(tokens)
    final_state = num_states
    arcs = ""
    for i in range(num_states):
        for j in range(num_states):
            if i == j:
                arcs += f"{i} {i} {tokens[i]} 0 0.0\n"
            else:
                arcs += f"{i} {j} {tokens[j]} {tokens[j]} 0.0\n"
        arcs += f"{i} {final_state} -1 -1 0.0\n"
    arcs += f"{final_state}"
    ans = k2.Fsa.from_str(arcs, num_aux_labels=1)
    ans = k2.arc_sort(ans)
    return ans

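# A minimal usage sketch (toy token set, not from the original source),
# following the same pattern as the CTC loss tests above: invert the topology,
# intersect it with a linear FSA over the label sequence, and invert back to
# obtain a per-utterance decoding graph.
ctc_topo_inv = k2.arc_sort(build_ctc_topo([0, 1, 2]).invert_())
linear_fsa = k2.linear_fsa([1, 2])
decoding_graph = k2.connect(k2.intersect(ctc_topo_inv, linear_fsa)).invert_()
print(decoding_graph)
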
def build_shared_blank_topo(tokens: List[int],
                            with_self_loops: bool = True) -> 'k2.Fsa':
    """Build the shared blank CTC topology.

    See https://github.com/k2-fsa/k2/issues/746#issuecomment-856421616
    """
    assert 0 in tokens, "We assume 0 is the ID of the blank symbol"

    tokens = tokens.copy()
    tokens.remove(0)
    num_tokens = len(tokens)

    start = 0
    final = num_tokens + 1

    arcs = []
    arcs.append([start, start, 0, 0, 0])
    arcs.append([start, final, -1, -1, 0])
    arcs.append([final])
    for i, p in enumerate(tokens):
        i += 1
        arcs.append([start, start, p, p, 0])
        arcs.append([start, i, p, p, 0])
        arcs.append([i, start, p, 0, 0])
        if with_self_loops:
            arcs.append([i, i, p, 0, 0])
    arcs = sorted(arcs, key=lambda arc: arc[0])
    arcs = [[str(i) for i in arc] for arc in arcs]
    arcs = [" ".join(arc) for arc in arcs]
    arcs = "\n".join(arcs)

    ans = k2.Fsa.from_str(arcs, num_aux_labels=1)
    ans = k2.arc_sort(ans)
    return ans

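# A small sketch comparing topology sizes (toy token IDs, not from the original
# source): the shared-blank topology above grows linearly with the vocabulary,
# while the fully-connected build_ctc_topo above grows quadratically, as is
# evident from its nested loop over states.
tokens = list(range(10))
print(build_shared_blank_topo(tokens).num_arcs)
print(build_ctc_topo(tokens).num_arcs)
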
def build_num_graphs(self, texts: List[str]) -> k2.Fsa:
    '''Convert transcripts to an FsaVec with the help of a lexicon
    and a word symbol table.

    Args:
      texts:
        Each element is a transcript containing words separated by spaces.
        For instance, it may be 'HELLO SNOWFALL', which contains
        two words.

    Returns:
      Return an FST (FsaVec) corresponding to the transcripts. Its `labels`
      are phone IDs and `aux_labels` are word IDs.
    '''
    word_ids_list = []
    for text in texts:
        word_ids = []
        for word in text.split(' '):
            if word in self.words:
                word_ids.append(self.words[word])
            else:
                word_ids.append(self.oov_id)
        word_ids_list.append(word_ids)

    fsa = k2.linear_fsa(word_ids_list, self.device)
    fsa = k2.add_epsilon_self_loops(fsa)

    num_graphs = k2.intersect(self.L_inv,
                              fsa,
                              treat_epsilons_specially=False).invert_()
    num_graphs = k2.arc_sort(num_graphs)
    return num_graphs

def __init__(
    self,
    asr_train_config: Union[Path, str],
    asr_model_file: Union[Path, str] = None,
    lm_train_config: Union[Path, str] = None,
    lm_file: Union[Path, str] = None,
    token_type: str = None,
    bpemodel: str = None,
    device: str = "cpu",
    maxlenratio: float = 0.0,
    minlenratio: float = 0.0,
    batch_size: int = 1,
    dtype: str = "float32",
    beam_size: int = 8,
    ctc_weight: float = 0.5,
    lm_weight: float = 1.0,
    penalty: float = 0.0,
    nbest: int = 1,
    streaming: bool = False,
    output_beam_size: int = 8,
):
    assert check_argument_types()

    # 1. Build ASR model
    asr_model, asr_train_args = ASRTask.build_model_from_file(
        asr_train_config, asr_model_file, device)
    asr_model.to(dtype=getattr(torch, dtype)).eval()

    token_list = asr_model.token_list

    # Build the CTC topology over all output tokens and move it to the
    # decoding device; it serves as the decoding graph.
    self.decode_graph = k2.arc_sort(
        build_ctc_topo(list(range(len(token_list))))).to(device)

    if token_type is None:
        token_type = asr_train_args.token_type
    if bpemodel is None:
        bpemodel = asr_train_args.bpemodel

    if token_type is None:
        tokenizer = None
    elif token_type == "bpe":
        if bpemodel is not None:
            tokenizer = build_tokenizer(token_type=token_type,
                                        bpemodel=bpemodel)
        else:
            tokenizer = None
    else:
        tokenizer = build_tokenizer(token_type=token_type)
    converter = TokenIDConverter(token_list=token_list)
    logging.info(f"Text tokenizer: {tokenizer}")
    logging.info(f"Running on : {device}")

    self.asr_model = asr_model
    self.asr_train_args = asr_train_args
    self.converter = converter
    self.tokenizer = tokenizer
    self.device = device
    self.dtype = dtype
    self.output_beam_size = output_beam_size

def compile_LG(L: Fsa, G: Fsa, labels_disambig_id_start: int,
               aux_labels_disambig_id_start: int) -> Fsa:
    """
    Creates a decoding graph using a lexicon fst ``L`` and language model fsa
    ``G``. Involves arc sorting, intersection, determinization, removal of
    disambiguation symbols and adding epsilon self-loops.

    Args:
        L:
            An ``Fsa`` that represents the lexicon (L), i.e. has phones as
            ``symbols`` and words as ``aux_symbols``.
        G:
            An ``Fsa`` that represents the language model (G), i.e. it's an
            acceptor with words as ``symbols``.
        labels_disambig_id_start:
            An integer ID corresponding to the first disambiguation symbol in
            the phonetic alphabet.
        aux_labels_disambig_id_start:
            An integer ID corresponding to the first disambiguation symbol in
            the words vocabulary.

    Returns:
        An ``Fsa`` representing the arc-sorted decoding graph LG.
    """
    L_inv = k2.arc_sort(L.invert_())
    G = k2.arc_sort(G)

    logging.debug("Intersecting L and G")
    LG = k2.intersect(L_inv, G)
    logging.debug(f'LG shape = {LG.shape}')

    logging.debug("Connecting L*G")
    LG = k2.connect(LG).invert_()
    logging.debug(f'LG shape = {LG.shape}')

    logging.debug("Determinizing L*G")
    LG = k2.determinize(LG)
    logging.debug(f'LG shape = {LG.shape}')

    logging.debug("Connecting det(L*G)")
    LG = k2.connect(LG)
    logging.debug(f'LG shape = {LG.shape}')

    logging.debug("Removing disambiguation symbols on L*G")
    LG.labels[LG.labels >= labels_disambig_id_start] = 0
    LG.aux_labels[LG.aux_labels >= aux_labels_disambig_id_start] = 0

    LG = k2.add_epsilon_self_loops(LG)
    LG = k2.arc_sort(LG)
    logging.debug(
        f'LG is arc sorted: {(LG.properties & k2.fsa_properties.ARC_SORTED) != 0}'
    )
    return LG

def test_case3(self):
    devices = [torch.device('cpu')]
    if torch.cuda.is_available():
        devices.append(torch.device('cuda'))

    for device in devices:
        # (T, N, C)
        torch_activation = torch.tensor([[
            [-5, -4, -3, -2, -1],
            [-10, -9, -8, -7, -6],
            [-15, -14, -13, -12, -11.],
        ]]).permute(1, 0, 2).to(device).requires_grad_(True)
        torch_activation = torch_activation.to(torch.float32)
        torch_activation.requires_grad_(True)

        k2_activation = torch_activation.detach().clone().requires_grad_(
            True)

        torch_log_probs = torch.nn.functional.log_softmax(
            torch_activation, dim=-1)  # (T, N, C)

        # we have only one sequence and its labels are `b,c`
        targets = torch.tensor([2, 3]).to(device)
        input_lengths = torch.tensor([3]).to(device)
        target_lengths = torch.tensor([2]).to(device)

        torch_loss = torch.nn.functional.ctc_loss(
            log_probs=torch_log_probs,
            targets=targets,
            input_lengths=input_lengths,
            target_lengths=target_lengths,
            reduction='none')

        act = k2_activation.permute(1, 0, 2)  # (T, N, C) -> (N, T, C)
        k2_log_probs = torch.nn.functional.log_softmax(act, dim=-1)

        supervision_segments = torch.tensor([[0, 0, 3]], dtype=torch.int32)
        dense_fsa_vec = k2.DenseFsaVec(k2_log_probs,
                                       supervision_segments).to(device)

        ctc_topo_inv = k2.arc_sort(
            build_ctc_topo([0, 1, 2, 3, 4]).invert_())
        linear_fsa = k2.linear_fsa([2, 3])
        decoding_graph = k2.intersect(ctc_topo_inv, linear_fsa)
        decoding_graph = k2.connect(decoding_graph).invert_().to(device)

        target_graph = k2.intersect_dense(decoding_graph, dense_fsa_vec,
                                          100.0)

        k2_scores = target_graph.get_tot_scores(log_semiring=True,
                                                use_double_scores=False)
        assert torch.allclose(torch_loss, -1 * k2_scores)
        assert torch.allclose(torch_loss,
                              torch.tensor([4.938850402832]).to(device))

        torch_loss.backward()
        (-k2_scores).backward()
        assert torch.allclose(torch_activation.grad, k2_activation.grad)

def compile_one_and_cache(self, text: str) -> k2.Fsa:
    tokens = (token if token in self.words else self.oov
              for token in text.split(' '))
    word_ids = [self.words[token] for token in tokens]
    fsa = k2.linear_fsa(word_ids)
    decoding_graph = k2.connect(k2.intersect(fsa, self.L_inv)).invert_()
    decoding_graph = k2.arc_sort(decoding_graph)
    decoding_graph = k2.compose(self.ctc_topo, decoding_graph)
    decoding_graph = k2.connect(decoding_graph)
    return decoding_graph

def test_random_case1(self):
    # 1 sequence
    devices = [torch.device('cpu')]
    if torch.cuda.is_available():
        devices.append(torch.device('cuda', 0))

    for device in devices:
        T = torch.randint(10, 100, (1,)).item()
        C = torch.randint(20, 30, (1,)).item()
        torch_activation = torch.rand((1, T + 10, C),
                                      dtype=torch.float32,
                                      device=device).requires_grad_(True)
        k2_activation = torch_activation.detach().clone().requires_grad_(
            True)

        # [N, T, C] -> [T, N, C]
        torch_log_probs = torch.nn.functional.log_softmax(
            torch_activation.permute(1, 0, 2), dim=-1)

        input_lengths = torch.tensor([T]).to(device)
        target_lengths = torch.randint(1, T, (1,)).to(device)
        targets = torch.randint(1, C - 1,
                                (target_lengths.item(),)).to(device)

        torch_loss = torch.nn.functional.ctc_loss(
            log_probs=torch_log_probs,
            targets=targets,
            input_lengths=input_lengths,
            target_lengths=target_lengths,
            reduction='none')
        k2_log_probs = torch.nn.functional.log_softmax(k2_activation,
                                                       dim=-1)
        supervision_segments = torch.tensor([[0, 0, T]], dtype=torch.int32)
        dense_fsa_vec = k2.DenseFsaVec(k2_log_probs,
                                       supervision_segments).to(device)

        ctc_topo_inv = k2.arc_sort(
            build_ctc_topo(list(range(C))).invert_())
        linear_fsa = k2.linear_fsa([targets.tolist()])
        decoding_graph = k2.intersect(ctc_topo_inv, linear_fsa)
        decoding_graph = k2.connect(decoding_graph).invert_().to(device)

        target_graph = k2.intersect_dense(decoding_graph, dense_fsa_vec,
                                          100.0)

        k2_scores = target_graph.get_tot_scores(log_semiring=True,
                                                use_double_scores=False)
        assert torch.allclose(torch_loss, -1 * k2_scores)
        scale = torch.rand_like(torch_loss) * 100
        (torch_loss * scale).sum().backward()
        (-k2_scores * scale).sum().backward()
        assert torch.allclose(torch_activation.grad,
                              k2_activation.grad,
                              atol=1e-2)

def test_composition_equivalence(self):
    index = _generate_fsa_vec()
    index = k2.arc_sort(k2.connect(k2.remove_epsilon(index)))

    src = _generate_fsa_vec()

    replace = k2.replace_fsa(src, index, 1)
    replace = k2.top_sort(replace)

    f_fsa = _construct_f(src)
    f_fsa = k2.arc_sort(f_fsa)
    intersect = k2.intersect(index, f_fsa, treat_epsilons_specially=True)
    intersect = k2.invert(intersect)
    intersect = k2.top_sort(intersect)
    delattr(intersect, 'aux_labels')

    assert k2.is_rand_equivalent(replace,
                                 intersect,
                                 log_semiring=True,
                                 delta=1e-3)

def test_arc_sort(self):
    s = r'''
        0 1 2
        0 4 0
        0 2 0
        1 2 1
        1 3 0
        2 1 0
        4
    '''
    fsa = k2.str_to_fsa(s)
    arc_map = k2.IntArray1.create_array_with_size(fsa.size2)
    k2.arc_sort(fsa, arc_map)

    expected_arc_indexes = torch.IntTensor([0, 3, 5, 6, 6, 6])
    expected_arcs = torch.IntTensor([[0, 2, 0], [0, 4, 0], [0, 1, 2],
                                     [1, 3, 0], [1, 2, 1], [2, 1, 0]])
    expected_arc_map = torch.IntTensor([2, 1, 0, 4, 3, 5])
    self.assertTrue(torch.equal(fsa.indexes, expected_arc_indexes))
    self.assertTrue(torch.equal(fsa.data, expected_arcs))
    self.assertTrue(torch.equal(arc_map.data, expected_arc_map))

def compute_am_scores(lats: k2.Fsa, word_fsas_with_epsilon_loops: k2.Fsa,
                      path_to_seq_map: torch.Tensor) -> torch.Tensor:
    '''Compute AM scores of n-best lists (represented as word_fsas).

    Args:
      lats:
        An FsaVec, which is the output of `k2.intersect_dense_pruned`.
        It must have the attribute `lm_scores`.
      word_fsas_with_epsilon_loops:
        An FsaVec representing an n-best list. Note that it has been
        processed by `k2.add_epsilon_self_loops`.
      path_to_seq_map:
        A 1-D torch.Tensor with dtype torch.int32. path_to_seq_map[i]
        indicates which sequence the i-th Fsa in
        word_fsas_with_epsilon_loops belongs to.
        path_to_seq_map.numel() == word_fsas_with_epsilon_loops.arcs.dim0().
    Returns:
      Return a 1-D torch.Tensor containing the AM scores of each path.
      `ans.numel() == word_fsas_with_epsilon_loops.shape[0]`
    '''
    device = lats.device
    assert len(lats.shape) == 3
    assert hasattr(lats, 'lm_scores')

    # k2.compose() currently does not support b_to_a_map. To avoid
    # replicating `lats`, we use k2.intersect_device here.
    #
    # lats has phone IDs as `labels` and word IDs as `aux_labels`, so we
    # need to invert it here.
    inverted_lats = k2.invert(lats)

    # Now the `labels` of inverted_lats are word IDs (a 1-D torch.Tensor)
    # and its `aux_labels` are phone IDs (a k2.RaggedInt with 2 axes).

    # Remove its `aux_labels` since it is not needed in the
    # following computation
    del inverted_lats.aux_labels
    inverted_lats = k2.arc_sort(inverted_lats)

    am_path_lats = _intersect_device(inverted_lats,
                                     word_fsas_with_epsilon_loops,
                                     b_to_a_map=path_to_seq_map,
                                     sorted_match_a=True)

    # NOTE: `k2.connect` and `k2.top_sort` support only CPU at present
    am_path_lats = k2.top_sort(k2.connect(am_path_lats.to('cpu'))).to(device)

    # The `scores` of every arc consists of `am_scores` and `lm_scores`
    am_path_lats.scores = am_path_lats.scores - am_path_lats.lm_scores

    am_scores = am_path_lats.get_tot_scores(True, True)

    return am_scores

def create_decoding_graph(texts, graph, symbols):
    fsas = []
    for text in texts:
        filter_text = [
            i if i in symbols._sym2id else '<UNK>' for i in text.split(' ')
        ]
        word_ids = [symbols.get(i) for i in filter_text]
        fsa = k2.linear_fsa(word_ids)
        fsa = k2.arc_sort(fsa)
        decoding_graph = k2.intersect(fsa, graph).invert_()
        decoding_graph = k2.add_epsilon_self_loops(decoding_graph)
        fsas.append(decoding_graph)
    return k2.create_fsa_vec(fsas)

def test_case2(self):
    for device in self.devices:
        # (T, N, C)
        torch_activation = torch.arange(1, 16).reshape(1, 3, 5).permute(
            1, 0, 2).to(device)
        torch_activation = torch_activation.to(torch.float32)
        torch_activation.requires_grad_(True)

        k2_activation = torch_activation.detach().clone().requires_grad_(
            True)

        torch_log_probs = torch.nn.functional.log_softmax(
            torch_activation, dim=-1)  # (T, N, C)

        # we have only one sequence and its labels are `c,c`
        targets = torch.tensor([3, 3]).to(device)
        input_lengths = torch.tensor([3]).to(device)
        target_lengths = torch.tensor([2]).to(device)

        torch_loss = torch.nn.functional.ctc_loss(
            log_probs=torch_log_probs,
            targets=targets,
            input_lengths=input_lengths,
            target_lengths=target_lengths,
            reduction='none')

        act = k2_activation.permute(1, 0, 2)  # (T, N, C) -> (N, T, C)
        k2_log_probs = torch.nn.functional.log_softmax(act, dim=-1)

        supervision_segments = torch.tensor([[0, 0, 3]], dtype=torch.int32)
        dense_fsa_vec = k2.DenseFsaVec(k2_log_probs,
                                       supervision_segments).to(device)

        ctc_topo_inv = k2.arc_sort(
            build_ctc_topo([0, 1, 2, 3, 4]).invert_())
        linear_fsa = k2.linear_fsa([3, 3])
        decoding_graph = k2.intersect(ctc_topo_inv, linear_fsa)
        decoding_graph = k2.connect(decoding_graph).invert_().to(device)

        target_graph = k2.intersect_dense(decoding_graph, dense_fsa_vec,
                                          100.0)

        k2_scores = target_graph.get_tot_scores(log_semiring=True,
                                                use_double_scores=False)
        assert torch.allclose(torch_loss, -1 * k2_scores)
        assert torch.allclose(torch_loss,
                              torch.tensor([7.355742931366]).to(device))

        torch_loss.backward()
        (-k2_scores).backward()
        assert torch.allclose(torch_activation.grad, k2_activation.grad)

def compile(self, targets: torch.Tensor,
            target_lengths: torch.Tensor) -> 'k2.Fsa':
    token_ids_list = [
        t[:l].tolist() for t, l in zip(targets, target_lengths)
    ]
    # see https://github.com/k2-fsa/k2/issues/835
    label_graph = k2.linear_fsa(token_ids_list).to(self.device)
    label_graph.aux_labels = label_graph.labels.clone()
    decoding_graphs = compose_with_self_loops(self.base_graph, label_graph)
    decoding_graphs = k2.arc_sort(decoding_graphs).to(self.device)

    # make sure the gradient is not accumulated
    decoding_graphs.requires_grad_(False)
    return decoding_graphs

def __init__(self,
             lexicon: Lexicon,
             device: torch.device,
             oov: str = '<UNK>'):
    '''
    Args:
      lexicon:
        It contains L_inv (whose labels are words and aux_labels are phones)
        as well as the phone and word symbol tables.
      device:
        The device on which the graphs are built.
      oov:
        Out of vocabulary word.
    '''
    self.lexicon = lexicon
    L_inv = self.lexicon.L_inv.to(device)

    # Arc-sort L_inv if it is not already arc-sorted.
    if (L_inv.properties & k2.fsa_properties.ARC_SORTED) == 0:
        L_inv = k2.arc_sort(L_inv)

    assert L_inv.requires_grad is False
    assert oov in self.lexicon.words

    self.L_inv = L_inv
    self.oov_id = self.lexicon.words[oov]
    self.oov = oov
    self.device = device

    phone_symbols = get_phone_symbols(self.lexicon.phones)
    phone_symbols_with_blank = [0] + phone_symbols

    ctc_topo = build_ctc_topo(phone_symbols_with_blank).to(device)
    assert ctc_topo.requires_grad is False

    self.ctc_topo_inv = k2.arc_sort(ctc_topo.invert_())
