def test1(self):
    s = '''
        0 4 1 1
        0 1 1 1
        1 2 2 2
        1 3 3 3
        2 7 1 4
        3 7 1 5
        4 6 1 2
        4 6 1 3
        4 5 1 3
        4 8 -1 2
        5 8 -1 4
        6 8 -1 3
        7 8 -1 5
        8
    '''
    fsa = k2.Fsa.from_str(s)
    prop = fsa.properties
    self.assertFalse(
        prop & k2.fsa_properties.ARC_SORTED_AND_DETERMINISTIC != 0)
    dest = k2.determinize(fsa)
    log_semiring = False
    self.assertTrue(k2.is_rand_equivalent(fsa, dest, log_semiring))
    arc_sorted = k2.arc_sort(dest)
    prop = arc_sorted.properties
    self.assertTrue(
        prop & k2.fsa_properties.ARC_SORTED_AND_DETERMINISTIC != 0)
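# A note on the property checks above: `fsa.properties` is a bit field, so a
# single property is tested by AND-ing with its flag. A minimal sketch of the
# pattern with made-up flag values (the real values live in
# `k2.fsa_properties`; the names below are hypothetical stand-ins):
_ARC_SORTED = 0x2                     # hypothetical flag value
_ARC_SORTED_AND_DETERMINISTIC = 0x4   # hypothetical flag value
_props = _ARC_SORTED | _ARC_SORTED_AND_DETERMINISTIC
assert _props & _ARC_SORTED_AND_DETERMINISTIC != 0   # flag is set
assert _ARC_SORTED & _ARC_SORTED_AND_DETERMINISTIC == 0  # flag not set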
def test_random(self):
    while True:
        fsa = k2.random_fsa(max_symbol=20,
                            min_num_arcs=50,
                            max_num_arcs=500)
        fsa = k2.arc_sort(k2.connect(k2.remove_epsilon(fsa)))
        prop = fsa.properties
        # we need a non-deterministic fsa
        if not prop & k2.fsa_properties.ARC_SORTED_AND_DETERMINISTIC:
            break
    log_semiring = False

    # test weight pushing in the tropical semiring
    dest_max = k2.determinize(
        fsa, k2.DeterminizeWeightPushingType.kTropicalWeightPushing)
    self.assertTrue(
        k2.is_rand_equivalent(fsa, dest_max, log_semiring, delta=1e-3))

    # test weight pushing in the log semiring
    dest_log = k2.determinize(
        fsa, k2.DeterminizeWeightPushingType.kLogWeightPushing)
    self.assertTrue(
        k2.is_rand_equivalent(fsa, dest_log, log_semiring, delta=1e-3))
def test1(self):
    s = '''
        0 4 1 1
        0 1 1 1
        1 2 2 2
        1 3 3 3
        2 7 1 4
        3 7 1 5
        4 6 1 2
        4 6 1 3
        4 5 1 3
        4 8 -1 2
        5 8 -1 4
        6 8 -1 3
        7 8 -1 5
        8
    '''
    fsa = k2.Fsa.from_str(s)
    prop = fsa.properties
    self.assertFalse(
        prop & k2.fsa_properties.ARC_SORTED_AND_DETERMINISTIC != 0)
    dest = k2.determinize(fsa)
    log_semiring = False
    self.assertTrue(k2.is_rand_equivalent(fsa, dest, log_semiring))
    arc_sorted = k2.arc_sort(dest)
    prop = arc_sorted.properties
    self.assertTrue(
        prop & k2.fsa_properties.ARC_SORTED_AND_DETERMINISTIC != 0)

    # test weight pushing in the tropical semiring
    dest_max = k2.determinize(
        fsa, k2.DeterminizeWeightPushingType.kTropicalWeightPushing)
    self.assertTrue(k2.is_rand_equivalent(dest, dest_max, log_semiring))

    # test weight pushing in the log semiring
    dest_log = k2.determinize(
        fsa, k2.DeterminizeWeightPushingType.kLogWeightPushing)
    self.assertTrue(k2.is_rand_equivalent(dest, dest_log, log_semiring))
def compile_LG(L: Fsa, G: Fsa, labels_disambig_id_start: int,
               aux_labels_disambig_id_start: int) -> Fsa:
    """Creates a decoding graph using a lexicon fst ``L`` and language model
    fsa ``G``. Involves arc sorting, intersection, determinization, removal
    of disambiguation symbols and adding epsilon self-loops.

    Args:
        L:
            An ``Fsa`` that represents the lexicon (L), i.e. has phones as
            ``symbols`` and words as ``aux_symbols``.
        G:
            An ``Fsa`` that represents the language model (G), i.e. it is an
            acceptor with words as ``symbols``.
        labels_disambig_id_start:
            An integer ID corresponding to the first disambiguation symbol in
            the phonetic alphabet.
        aux_labels_disambig_id_start:
            An integer ID corresponding to the first disambiguation symbol in
            the words vocabulary.
    Returns:
        The decoding graph LG.
    """
    L_inv = k2.arc_sort(L.invert_())
    G = k2.arc_sort(G)

    logging.debug("Intersecting L and G")
    LG = k2.intersect(L_inv, G)
    logging.debug(f'LG shape = {LG.shape}')

    logging.debug("Connecting L*G")
    LG = k2.connect(LG).invert_()
    logging.debug(f'LG shape = {LG.shape}')

    logging.debug("Determinizing L*G")
    LG = k2.determinize(LG)
    logging.debug(f'LG shape = {LG.shape}')

    logging.debug("Connecting det(L*G)")
    LG = k2.connect(LG)
    logging.debug(f'LG shape = {LG.shape}')

    logging.debug("Removing disambiguation symbols on L*G")
    LG.labels[LG.labels >= labels_disambig_id_start] = 0
    LG.aux_labels[LG.aux_labels >= aux_labels_disambig_id_start] = 0

    LG = k2.add_epsilon_self_loops(LG)
    LG = k2.arc_sort(LG)
    logging.debug(
        f'LG is arc sorted: {(LG.properties & k2.fsa_properties.ARC_SORTED) != 0}'
    )
    return LG
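# The masked assignment `LG.labels[LG.labels >= labels_disambig_id_start] = 0`
# above zeroes every disambiguation symbol in one shot (the real code operates
# on torch tensors). A minimal pure-Python sketch of the same semantics, with
# toy label IDs chosen for illustration:
_labels = [1, 5, 100, 3, 101]
_labels_disambig_id_start = 100  # hypothetical first disambiguation ID
_labels = [0 if x >= _labels_disambig_id_start else x for x in _labels]
assert _labels == [1, 5, 0, 3, 0]  # IDs >= 100 replaced by epsilon (0)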
def compile_LG(L: Fsa, G: Fsa, ctc_topo: Fsa, labels_disambig_id_start: int,
               aux_labels_disambig_id_start: int) -> Fsa:
    """Creates a decoding graph using a lexicon fst ``L`` and language model
    fsa ``G``. Involves arc sorting, intersection, determinization, removal
    of disambiguation symbols, epsilon removal and composition with the CTC
    topology.

    Args:
        L:
            An ``Fsa`` that represents the lexicon (L), i.e. has phones as
            ``symbols`` and words as ``aux_symbols``.
        G:
            An ``Fsa`` that represents the language model (G), i.e. it is an
            acceptor with words as ``symbols``.
        ctc_topo:
            The CTC topology fst. In it, when 0 appears on the left side, it
            represents the blank symbol; when it appears on the right side,
            it indicates an epsilon.
        labels_disambig_id_start:
            An integer ID corresponding to the first disambiguation symbol in
            the phonetic alphabet.
        aux_labels_disambig_id_start:
            An integer ID corresponding to the first disambiguation symbol in
            the words vocabulary.
    Returns:
        The decoding graph LG.
    """
    L = k2.arc_sort(L)
    G = k2.arc_sort(G)

    logging.info("Intersecting L and G")
    LG = k2.compose(L, G)
    logging.info(f'LG shape = {LG.shape}')

    logging.info("Connecting L*G")
    LG = k2.connect(LG)
    logging.info(f'LG shape = {LG.shape}')

    logging.info("Determinizing L*G")
    LG = k2.determinize(LG)
    logging.info(f'LG shape = {LG.shape}')

    logging.info("Connecting det(L*G)")
    LG = k2.connect(LG)
    logging.info(f'LG shape = {LG.shape}')

    logging.info("Removing disambiguation symbols on L*G")
    LG.labels[LG.labels >= labels_disambig_id_start] = 0
    if isinstance(LG.aux_labels, torch.Tensor):
        LG.aux_labels[LG.aux_labels >= aux_labels_disambig_id_start] = 0
    else:
        LG.aux_labels.values()[
            LG.aux_labels.values() >= aux_labels_disambig_id_start] = 0

    logging.info("Removing epsilons")
    LG = k2.remove_epsilon(LG)
    logging.info(f'LG shape = {LG.shape}')

    logging.info("Connecting rm-eps(det(L*G))")
    LG = k2.connect(LG)
    logging.info(f'LG shape = {LG.shape}')
    LG.aux_labels = k2.ragged.remove_values_eq(LG.aux_labels, 0)

    logging.info("Arc sorting LG")
    LG = k2.arc_sort(LG)

    logging.info("Composing ctc_topo and LG")
    LG = k2.compose(ctc_topo, LG, inner_labels='phones')

    logging.info("Connecting ctc_topo*LG")
    LG = k2.connect(LG)

    logging.info("Arc sorting ctc_topo*LG")
    LG = k2.arc_sort(LG)
    logging.info(
        f'LG is arc sorted: {(LG.properties & k2.fsa_properties.ARC_SORTED) != 0}'
    )
    return LG
def compile_HLG(L: Fsa, G: Fsa, H: Fsa, labels_disambig_id_start: int,
                aux_labels_disambig_id_start: int) -> Fsa:
    """Creates a decoding graph using a lexicon fst ``L``, a language model
    fsa ``G`` and a topology fst ``H``. Involves arc sorting, intersection,
    determinization, removal of disambiguation symbols, epsilon removal and
    composition with ``H``.

    Args:
        L:
            An ``Fsa`` that represents the lexicon (L), i.e. has phones as
            ``symbols`` and words as ``aux_symbols``.
        G:
            An ``Fsa`` that represents the language model (G), i.e. it is an
            acceptor with words as ``symbols``.
        H:
            An ``Fsa`` that represents a specific topology used to convert
            the network outputs to a sequence of phones. Typically, it is a
            CTC topology fst, in which when 0 appears on the left side, it
            represents the blank symbol; when it appears on the right side,
            it indicates an epsilon.
        labels_disambig_id_start:
            An integer ID corresponding to the first disambiguation symbol in
            the phonetic alphabet.
        aux_labels_disambig_id_start:
            An integer ID corresponding to the first disambiguation symbol in
            the words vocabulary.
    Returns:
        The decoding graph HLG.
    """
    L = k2.arc_sort(L)
    G = k2.arc_sort(G)

    logging.info("Intersecting L and G")
    LG = k2.compose(L, G)
    logging.info(f'LG shape = {LG.shape}')

    logging.info("Connecting L*G")
    LG = k2.connect(LG)
    logging.info(f'LG shape = {LG.shape}')

    logging.info("Determinizing L*G")
    LG = k2.determinize(LG)
    logging.info(f'LG shape = {LG.shape}')

    logging.info("Connecting det(L*G)")
    LG = k2.connect(LG)
    logging.info(f'LG shape = {LG.shape}')

    logging.info("Removing disambiguation symbols on L*G")
    LG.labels[LG.labels >= labels_disambig_id_start] = 0
    if isinstance(LG.aux_labels, torch.Tensor):
        LG.aux_labels[LG.aux_labels >= aux_labels_disambig_id_start] = 0
    else:
        LG.aux_labels.values()[
            LG.aux_labels.values() >= aux_labels_disambig_id_start] = 0

    logging.info("Removing epsilons")
    LG = k2.remove_epsilon(LG)
    logging.info(f'LG shape = {LG.shape}')

    logging.info("Connecting rm-eps(det(L*G))")
    LG = k2.connect(LG)
    logging.info(f'LG shape = {LG.shape}')
    LG.aux_labels = k2.ragged.remove_values_eq(LG.aux_labels, 0)

    logging.info("Arc sorting LG")
    LG = k2.arc_sort(LG)

    logging.info("Composing H and LG")
    HLG = k2.compose(H, LG, inner_labels='phones')

    logging.info("Connecting HLG")
    HLG = k2.connect(HLG)

    logging.info("Arc sorting HLG")
    HLG = k2.arc_sort(HLG)
    logging.info(
        f'HLG is arc sorted: {(HLG.properties & k2.fsa_properties.ARC_SORTED) != 0}'
    )

    # Attach a new attribute `lm_scores` so that we can recover
    # the `am_scores` later.
    # The scores on an arc consist of two parts:
    #     scores = am_scores + lm_scores
    # NOTE: we assume that both kinds of scores are in log-space.
    HLG.lm_scores = HLG.scores.clone()

    return HLG
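# Why `HLG.lm_scores = HLG.scores.clone()` above: at graph-compilation time
# the arc scores are pure LM scores, so stashing a copy lets later code
# recover the acoustic part after decoding has added AM scores. A toy
# illustration with plain floats (the AM score value below is made up; both
# score kinds are assumed to be in log-space, as the comment above notes):
_lm_score = -2.5
_total_score = _lm_score + (-4.5)     # decoding adds a hypothetical AM score
_am_score = _total_score - _lm_score  # subtract the stashed LM score
assert _am_score == -4.5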
def compile_LG(L: Fsa, G: Fsa, ctc_topo_inv: Fsa,
               labels_disambig_id_start: int,
               aux_labels_disambig_id_start: int) -> Fsa:
    """Creates a decoding graph using a lexicon fst ``L`` and language model
    fsa ``G``. Involves arc sorting, intersection, determinization, removal
    of disambiguation symbols, epsilon removal and composition with the
    inverted CTC topology.

    Args:
        L:
            An ``Fsa`` that represents the lexicon (L), i.e. has phones as
            ``symbols`` and words as ``aux_symbols``.
        G:
            An ``Fsa`` that represents the language model (G), i.e. it is an
            acceptor with words as ``symbols``.
        ctc_topo_inv:
            The inverted CTC topology fst, i.e. its ``labels`` contain phone
            IDs and epsilons are in its ``aux_labels``.
        labels_disambig_id_start:
            An integer ID corresponding to the first disambiguation symbol in
            the phonetic alphabet.
        aux_labels_disambig_id_start:
            An integer ID corresponding to the first disambiguation symbol in
            the words vocabulary.
    Returns:
        The decoding graph LG.
    """
    L_inv = k2.arc_sort(L.invert_())
    G = k2.arc_sort(G)

    logging.debug("Intersecting L and G")
    LG = k2.intersect(L_inv, G)
    logging.debug(f'LG shape = {LG.shape}')

    logging.debug("Connecting L*G")
    LG = k2.connect(LG).invert_()
    logging.debug(f'LG shape = {LG.shape}')

    logging.debug("Determinizing L*G")
    LG = k2.determinize(LG)
    logging.debug(f'LG shape = {LG.shape}')

    logging.debug("Connecting det(L*G)")
    LG = k2.connect(LG)
    logging.debug(f'LG shape = {LG.shape}')

    logging.debug("Removing disambiguation symbols on L*G")
    LG.labels[LG.labels >= labels_disambig_id_start] = 0
    if isinstance(LG.aux_labels, torch.Tensor):
        LG.aux_labels[LG.aux_labels >= aux_labels_disambig_id_start] = 0
    else:
        LG.aux_labels.values()[
            LG.aux_labels.values() >= aux_labels_disambig_id_start] = 0

    logging.debug("Removing epsilons")
    LG = k2.remove_epsilons_iterative_tropical(LG)
    logging.debug(f'LG shape = {LG.shape}')

    logging.debug("Connecting rm-eps(det(L*G))")
    LG = k2.connect(LG)
    logging.debug(f'LG shape = {LG.shape}')
    LG.aux_labels = k2.ragged.remove_values_eq(LG.aux_labels, 0)

    logging.debug("Arc sorting")
    LG = k2.arc_sort(LG)

    logging.debug("Composing ctc_topo_inv and LG")
    LG = k2.compose(ctc_topo_inv, LG)

    logging.debug("Connecting")
    LG = k2.connect(LG)

    logging.debug("Arc sorting")
    LG = k2.arc_sort(LG)
    logging.debug(
        f'LG is arc sorted: {(LG.properties & k2.fsa_properties.ARC_SORTED) != 0}'
    )
    return LG