コード例 #1
0
ファイル: determinize_test.py プロジェクト: zhu-han/k2
 def test1(self):
     """Determinize a small hand-written FSA and verify the outcome.

     Checks that the input does not carry the
     ARC_SORTED_AND_DETERMINISTIC property, that the determinized FSA
     passes ``k2.is_rand_equivalent`` against the input, and that the
     arc-sorted determinized FSA does carry that property.
     """
     # One arc per line; the last line holds only the final state.
     # NOTE(review): fields look like "src dest label score" -- confirm
     # against the k2.Fsa.from_str documentation.
     text = '''
         0 4 1 1
         0 1 1 1
         1 2 2 2
         1 3 3 3
         2 7 1 4
         3 7 1 5
         4 6 1 2
         4 6 1 3
         4 5 1 3
         4 8 -1 2
         5 8 -1 4
         6 8 -1 3
         7 8 -1 5
         8
     '''
     source = k2.Fsa.from_str(text)
     det_flag = k2.fsa_properties.ARC_SORTED_AND_DETERMINISTIC

     # The input must not be flagged as deterministic to begin with.
     self.assertFalse((source.properties & det_flag) != 0)

     determinized = k2.determinize(source)
     use_log_semiring = False
     self.assertTrue(
         k2.is_rand_equivalent(source, determinized, use_log_semiring))

     # After arc sorting, the determinized FSA must carry the flag.
     sorted_properties = k2.arc_sort(determinized).properties
     self.assertTrue((sorted_properties & det_flag) != 0)
コード例 #2
0
 def test_random(self):
     """Determinize a random FSA with each weight-pushing type.

     Random FSAs are drawn until one lacks the
     ARC_SORTED_AND_DETERMINISTIC property after epsilon removal,
     connection and arc sorting.  Each determinized result is then
     compared to that input with ``k2.is_rand_equivalent``.
     """
     det_flag = k2.fsa_properties.ARC_SORTED_AND_DETERMINISTIC
     while True:
         candidate = k2.random_fsa(max_symbol=20,
                                   min_num_arcs=50,
                                   max_num_arcs=500)
         without_eps = k2.remove_epsilon(candidate)
         fsa = k2.arc_sort(k2.connect(without_eps))
         # we need non-deterministic fsa
         already_deterministic = fsa.properties & det_flag
         if not already_deterministic:
             break

     use_log_semiring = False
     # Tropical weight pushing first, then log weight pushing.
     push_types = (
         k2.DeterminizeWeightPushingType.kTropicalWeightPushing,
         k2.DeterminizeWeightPushingType.kLogWeightPushing,
     )
     for push_type in push_types:
         determinized = k2.determinize(fsa, push_type)
         self.assertTrue(
             k2.is_rand_equivalent(fsa,
                                   determinized,
                                   use_log_semiring,
                                   delta=1e-3))
コード例 #3
0
 def test1(self):
     """Determinize a small FSA with and without explicit weight pushing.

     Checks that the input does not carry the
     ARC_SORTED_AND_DETERMINISTIC property, that the default
     determinization passes ``k2.is_rand_equivalent`` against the input
     and carries the property once arc-sorted, and that determinizing
     with tropical or log weight pushing passes
     ``k2.is_rand_equivalent`` against the default result.
     """
     # One arc per line; the last line holds only the final state.
     # NOTE(review): fields look like "src dest label score" -- confirm
     # against the k2.Fsa.from_str documentation.
     text = '''
         0 4 1 1
         0 1 1 1
         1 2 2 2
         1 3 3 3
         2 7 1 4
         3 7 1 5
         4 6 1 2
         4 6 1 3
         4 5 1 3
         4 8 -1 2
         5 8 -1 4
         6 8 -1 3
         7 8 -1 5
         8
     '''
     source = k2.Fsa.from_str(text)
     det_flag = k2.fsa_properties.ARC_SORTED_AND_DETERMINISTIC

     # The input must not be flagged as deterministic to begin with.
     self.assertFalse((source.properties & det_flag) != 0)

     baseline = k2.determinize(source)
     use_log_semiring = False
     self.assertTrue(
         k2.is_rand_equivalent(source, baseline, use_log_semiring))

     # After arc sorting, the determinized FSA must carry the flag.
     sorted_properties = k2.arc_sort(baseline).properties
     self.assertTrue((sorted_properties & det_flag) != 0)

     # Tropical weight pushing first, then log weight pushing; each
     # result is compared against the default determinization.
     push_types = (
         k2.DeterminizeWeightPushingType.kTropicalWeightPushing,
         k2.DeterminizeWeightPushingType.kLogWeightPushing,
     )
     for push_type in push_types:
         pushed = k2.determinize(source, push_type)
         self.assertTrue(
             k2.is_rand_equivalent(baseline, pushed, use_log_semiring))
コード例 #4
0
ファイル: graph.py プロジェクト: juxiangyu/snowfall
def compile_LG(L: Fsa, G: Fsa, labels_disambig_id_start: int,
               aux_labels_disambig_id_start: int) -> Fsa:
    """
    Build an LG decoding graph from a lexicon fst ``L`` and a language
    model fsa ``G``.

    Steps, in order: invert and arc-sort ``L``, arc-sort ``G``, intersect
    them, connect, invert the result back, determinize, connect again,
    set disambiguation symbols to 0, add epsilon self-loops and arc-sort.

    Args:
        L:
            An ``Fsa`` that represents the lexicon (L), i.e. has phones as
            ``symbols`` and words as ``aux_symbols``.
        G:
            An ``Fsa`` that represents the language model (G), i.e. it's an
            acceptor with words as ``symbols``.
        labels_disambig_id_start:
            An integer ID corresponding to the first disambiguation symbol
            in the phonetic alphabet.  Every label greater than or equal
            to it is set to 0.
        aux_labels_disambig_id_start:
            An integer ID corresponding to the first disambiguation symbol
            in the words vocabulary.  Every aux label greater than or
            equal to it is set to 0.

    Returns:
        The resulting arc-sorted ``Fsa``.
    """
    # NOTE(review): the trailing underscore suggests ``invert_`` works in
    # place, in which case the caller's ``L`` is modified -- confirm.
    inverted_L = L.invert_()
    L_inv = k2.arc_sort(inverted_L)
    sorted_G = k2.arc_sort(G)

    logging.debug("Intersecting L and G")
    LG = k2.intersect(L_inv, sorted_G)
    logging.debug(f'LG shape = {LG.shape}')

    logging.debug("Connecting L*G")
    LG = k2.connect(LG)
    LG = LG.invert_()
    logging.debug(f'LG shape = {LG.shape}')

    logging.debug("Determinizing L*G")
    LG = k2.determinize(LG)
    logging.debug(f'LG shape = {LG.shape}')

    logging.debug("Connecting det(L*G)")
    LG = k2.connect(LG)
    logging.debug(f'LG shape = {LG.shape}')

    logging.debug("Removing disambiguation symbols on L*G")
    label_mask = LG.labels >= labels_disambig_id_start
    LG.labels[label_mask] = 0
    aux_mask = LG.aux_labels >= aux_labels_disambig_id_start
    LG.aux_labels[aux_mask] = 0

    LG = k2.arc_sort(k2.add_epsilon_self_loops(LG))
    is_arc_sorted = (LG.properties & k2.fsa_properties.ARC_SORTED) != 0
    logging.debug(
        f'LG is arc sorted: {is_arc_sorted}'
    )
    return LG
コード例 #5
0
def compile_LG(L: Fsa, G: Fsa, ctc_topo: Fsa, labels_disambig_id_start: int,
               aux_labels_disambig_id_start: int) -> Fsa:
    """
    Build a decoding graph from a lexicon fst ``L``, a language model fsa
    ``G`` and a CTC topology ``ctc_topo``.

    Steps, in order: arc-sort ``L`` and ``G``, compose them, connect,
    determinize, connect, set disambiguation symbols to 0, remove
    epsilons, connect, drop 0 entries from ``aux_labels``, arc-sort,
    compose with ``ctc_topo``, connect and arc-sort again.

    Args:
        L:
            An ``Fsa`` that represents the lexicon (L), i.e. has phones as
            ``symbols`` and words as ``aux_symbols``.
        G:
            An ``Fsa`` that represents the language model (G), i.e. it's an
            acceptor with words as ``symbols``.
        ctc_topo:
            CTC topology fst, in which when 0 appears on the left side, it
            represents the blank symbol; when it appears on the right
            side, it indicates an epsilon.
        labels_disambig_id_start:
            An integer ID corresponding to the first disambiguation symbol
            in the phonetic alphabet.
        aux_labels_disambig_id_start:
            An integer ID corresponding to the first disambiguation symbol
            in the words vocabulary.

    Returns:
        The resulting arc-sorted ``Fsa``.
    """
    sorted_L = k2.arc_sort(L)
    sorted_G = k2.arc_sort(G)

    logging.info("Intersecting L and G")
    LG = k2.compose(sorted_L, sorted_G)
    logging.info(f'LG shape = {LG.shape}')

    logging.info("Connecting L*G")
    LG = k2.connect(LG)
    logging.info(f'LG shape = {LG.shape}')

    logging.info("Determinizing L*G")
    LG = k2.determinize(LG)
    logging.info(f'LG shape = {LG.shape}')

    logging.info("Connecting det(L*G)")
    LG = k2.connect(LG)
    logging.info(f'LG shape = {LG.shape}')

    logging.info("Removing disambiguation symbols on L*G")
    label_mask = LG.labels >= labels_disambig_id_start
    LG.labels[label_mask] = 0
    # aux_labels is either a plain tensor or an object exposing its
    # entries through ``values()``; zero the entries in place either way.
    aux = LG.aux_labels
    aux_values = aux if isinstance(aux, torch.Tensor) else aux.values()
    aux_values[aux_values >= aux_labels_disambig_id_start] = 0

    logging.info("Removing epsilons")
    LG = k2.remove_epsilon(LG)
    logging.info(f'LG shape = {LG.shape}')

    logging.info("Connecting rm-eps(det(L*G))")
    LG = k2.connect(LG)
    logging.info(f'LG shape = {LG.shape}')

    # Drop the entries that were set to 0 above.
    LG.aux_labels = k2.ragged.remove_values_eq(LG.aux_labels, 0)

    logging.info("Arc sorting LG")
    LG = k2.arc_sort(LG)

    logging.info("Composing ctc_topo LG")
    LG = k2.compose(ctc_topo, LG, inner_labels='phones')

    logging.info("Connecting LG")
    LG = k2.connect(LG)

    logging.info("Arc sorting LG")
    LG = k2.arc_sort(LG)
    is_arc_sorted = (LG.properties & k2.fsa_properties.ARC_SORTED) != 0
    logging.info(
        f'LG is arc sorted: {is_arc_sorted}'
    )
    return LG
コード例 #6
0
ファイル: graph.py プロジェクト: zhichaowang/snowfall
def compile_HLG(L: Fsa, G: Fsa, H: Fsa, labels_disambig_id_start: int,
                aux_labels_disambig_id_start: int) -> Fsa:
    """
    Creates an HLG decoding graph from a topology fst ``H``, a lexicon fst
    ``L`` and a language model fsa ``G``.

    Steps, in order: arc-sort ``L`` and ``G``, compose them, connect,
    determinize, connect, set disambiguation symbols to 0, remove
    epsilons, connect, drop 0 entries from ``aux_labels``, arc-sort,
    compose with ``H``, connect and arc-sort again.

    Args:
        L:
            An ``Fsa`` that represents the lexicon (L), i.e. has phones as
            ``symbols`` and words as ``aux_symbols``.
        G:
            An ``Fsa`` that represents the language model (G), i.e. it's an
            acceptor with words as ``symbols``.
        H:
            An ``Fsa`` that represents a specific topology used to convert
            the network outputs to a sequence of phones.
            Typically, it's a CTC topology fst, in which when 0 appears on
            the left side, it represents the blank symbol; when it appears
            on the right side, it indicates an epsilon.
        labels_disambig_id_start:
            An integer ID corresponding to the first disambiguation symbol
            in the phonetic alphabet.
        aux_labels_disambig_id_start:
            An integer ID corresponding to the first disambiguation symbol
            in the words vocabulary.

    Returns:
        The arc-sorted HLG ``Fsa``.  It carries an extra attribute
        ``lm_scores`` holding a clone of its ``scores``.
    """
    L = k2.arc_sort(L)
    G = k2.arc_sort(G)
    # The log messages below name the operation and graph actually used
    # (compose / HLG), so the log matches what the code does.
    logging.info("Composing L and G")
    LG = k2.compose(L, G)
    logging.info(f'LG shape = {LG.shape}')
    logging.info("Connecting L*G")
    LG = k2.connect(LG)
    logging.info(f'LG shape = {LG.shape}')
    logging.info("Determinizing L*G")
    LG = k2.determinize(LG)
    logging.info(f'LG shape = {LG.shape}')
    logging.info("Connecting det(L*G)")
    LG = k2.connect(LG)
    logging.info(f'LG shape = {LG.shape}')
    logging.info("Removing disambiguation symbols on L*G")
    LG.labels[LG.labels >= labels_disambig_id_start] = 0
    # aux_labels is either a plain tensor or an object exposing its
    # entries through ``values()``; zero the entries in place either way.
    if isinstance(LG.aux_labels, torch.Tensor):
        LG.aux_labels[LG.aux_labels >= aux_labels_disambig_id_start] = 0
    else:
        LG.aux_labels.values()[
            LG.aux_labels.values() >= aux_labels_disambig_id_start] = 0
    logging.info("Removing epsilons")
    LG = k2.remove_epsilon(LG)
    logging.info(f'LG shape = {LG.shape}')
    logging.info("Connecting rm-eps(det(L*G))")
    LG = k2.connect(LG)
    logging.info(f'LG shape = {LG.shape}')
    # Drop the entries that were set to 0 above.
    LG.aux_labels = k2.ragged.remove_values_eq(LG.aux_labels, 0)

    logging.info("Arc sorting LG")
    LG = k2.arc_sort(LG)

    logging.info("Composing H and LG")
    HLG = k2.compose(H, LG, inner_labels='phones')

    logging.info("Connecting HLG")
    HLG = k2.connect(HLG)

    logging.info("Arc sorting HLG")
    HLG = k2.arc_sort(HLG)
    logging.info(
        f'HLG is arc sorted: {(HLG.properties & k2.fsa_properties.ARC_SORTED) != 0}'
    )

    # Attach a new attribute `lm_scores` so that we can recover
    # the `am_scores` later.
    # The scores on an arc consists of two parts:
    #  scores = am_scores + lm_scores
    # NOTE: we assume that both kinds of scores are in log-space.
    HLG.lm_scores = HLG.scores.clone()
    return HLG
コード例 #7
0
ファイル: graph.py プロジェクト: yaguanghu/snowfall
def compile_LG(L: Fsa, G: Fsa, ctc_topo_inv: Fsa,
               labels_disambig_id_start: int,
               aux_labels_disambig_id_start: int) -> Fsa:
    """
    Build a decoding graph from a lexicon fst ``L``, a language model fsa
    ``G`` and an inverted CTC topology ``ctc_topo_inv``.

    Steps, in order: invert and arc-sort ``L``, arc-sort ``G``, intersect
    them, connect, invert the result back, determinize, connect, set
    disambiguation symbols to 0, remove epsilons, connect, drop 0 entries
    from ``aux_labels``, arc-sort, compose with ``ctc_topo_inv``, connect
    and arc-sort again.

    Args:
        L:
            An ``Fsa`` that represents the lexicon (L), i.e. has phones as
            ``symbols`` and words as ``aux_symbols``.
        G:
            An ``Fsa`` that represents the language model (G), i.e. it's an
            acceptor with words as ``symbols``.
        ctc_topo_inv:
            Epsilons are in `aux_labels` and `labels` contain phone IDs.
        labels_disambig_id_start:
            An integer ID corresponding to the first disambiguation symbol
            in the phonetic alphabet.
        aux_labels_disambig_id_start:
            An integer ID corresponding to the first disambiguation symbol
            in the words vocabulary.

    Returns:
        The resulting arc-sorted ``Fsa``.
    """
    # NOTE(review): the trailing underscore suggests ``invert_`` works in
    # place, in which case the caller's ``L`` is modified -- confirm.
    inverted_L = L.invert_()
    L_inv = k2.arc_sort(inverted_L)
    sorted_G = k2.arc_sort(G)

    logging.debug("Intersecting L and G")
    LG = k2.intersect(L_inv, sorted_G)
    logging.debug(f'LG shape = {LG.shape}')

    logging.debug("Connecting L*G")
    LG = k2.connect(LG)
    LG = LG.invert_()
    logging.debug(f'LG shape = {LG.shape}')

    logging.debug("Determinizing L*G")
    LG = k2.determinize(LG)
    logging.debug(f'LG shape = {LG.shape}')

    logging.debug("Connecting det(L*G)")
    LG = k2.connect(LG)
    logging.debug(f'LG shape = {LG.shape}')

    logging.debug("Removing disambiguation symbols on L*G")
    label_mask = LG.labels >= labels_disambig_id_start
    LG.labels[label_mask] = 0
    # aux_labels is either a plain tensor or an object exposing its
    # entries through ``values()``; zero the entries in place either way.
    aux = LG.aux_labels
    aux_values = aux if isinstance(aux, torch.Tensor) else aux.values()
    aux_values[aux_values >= aux_labels_disambig_id_start] = 0

    logging.debug("Removing epsilons")
    LG = k2.remove_epsilons_iterative_tropical(LG)
    logging.debug(f'LG shape = {LG.shape}')

    logging.debug("Connecting rm-eps(det(L*G))")
    LG = k2.connect(LG)
    logging.debug(f'LG shape = {LG.shape}')

    # Drop the entries that were set to 0 above.
    LG.aux_labels = k2.ragged.remove_values_eq(LG.aux_labels, 0)

    logging.debug("Arc sorting")
    LG = k2.arc_sort(LG)

    logging.debug("Composing")
    LG = k2.compose(ctc_topo_inv, LG)

    logging.debug("Connecting")
    LG = k2.connect(LG)

    logging.debug("Arc sorting")
    LG = k2.arc_sort(LG)
    is_arc_sorted = (LG.properties & k2.fsa_properties.ARC_SORTED) != 0
    logging.debug(
        f'LG is arc sorted: {is_arc_sorted}'
    )
    return LG