Exemple #1
0
    def test_no_repeated(self):
        '''Check that the standard and the modified CTC topologies agree
        when the transcript has no repeated neighboring symbols.

        Both topologies are composed with the transcript ``[1, 2, 3]`` and
        intersected with two frame-level label sequences. The non-zero
        aux_labels of the connected lattices have to be identical for both
        topologies and equal to the transcript.
        '''
        max_token = 3
        transcript = k2.linear_fsa([1, 2, 3])
        standard_graph = k2.compose(k2.ctc_topo(max_token, modified=False),
                                    transcript)
        modified_graph = k2.compose(k2.ctc_topo(max_token, modified=True),
                                    transcript)

        frame_labels = (
            [1, 1, 1, 0, 0, 2, 2, 3, 3],
            [1, 1, 0, 0, 2, 2, 0, 3, 3],
        )
        for frames in frame_labels:
            acceptor = k2.linear_fsa(frames)

            outputs = []
            for graph in (standard_graph, modified_graph):
                lattice = k2.connect(
                    k2.intersect(graph,
                                 acceptor,
                                 treat_epsilons_specially=False))
                kept = lattice.aux_labels[lattice.aux_labels != 0]
                # Drop the trailing -1 before comparing.
                outputs.append(kept[:-1])

            from_standard, from_modified = outputs
            assert torch.all(torch.eq(from_standard, from_modified))
            assert torch.all(
                torch.eq(from_modified, torch.tensor([1, 2, 3])))
Exemple #2
0
    def test_with_repeated(self):
        '''Compare the standard and the modified CTC topologies on a
        transcript that contains a repeated symbol (``[1, 2, 2]``).

        Two scenarios are checked:

          - The input has a blank between the two 2s: both topologies
            must produce the same aux_labels, equal to the transcript.
          - The input has no blank between the two 2s: the standard
            topology yields an empty lattice while the modified one
            still recovers the transcript.
        '''
        max_token = 2
        standard = k2.ctc_topo(max_token, modified=False)
        modified = k2.ctc_topo(max_token, modified=True)
        transcript = k2.linear_fsa([1, 2, 2])
        standard_graph = k2.compose(standard, transcript)
        modified_graph = k2.compose(modified, transcript)
        expected = torch.tensor([1, 2, 2])

        # There is a blank separating 2 in the input
        # so standard and modified ctc topo should be equivalent
        frames = k2.linear_fsa([1, 1, 2, 2, 0, 2, 2, 0, 0])
        lattice1 = k2.intersect(standard_graph,
                                frames,
                                treat_epsilons_specially=False)
        lattice2 = k2.intersect(modified_graph,
                                frames,
                                treat_epsilons_specially=False)
        lattice1 = k2.connect(lattice1)
        lattice2 = k2.connect(lattice2)

        aux_labels1 = lattice1.aux_labels[lattice1.aux_labels != 0]
        aux_labels2 = lattice2.aux_labels[lattice2.aux_labels != 0]
        aux_labels1 = aux_labels1[:-1]  # remove -1
        aux_labels2 = aux_labels2[:-1]
        assert torch.all(torch.eq(aux_labels1, aux_labels2))
        assert torch.all(torch.eq(aux_labels1, expected))

        # There are no blanks separating 2 in the input.
        # The standard ctc topo requires that there must be a blank
        # separating 2, so lattice1 in the following is empty
        frames = k2.linear_fsa([1, 1, 2, 2, 0, 0])
        lattice1 = k2.intersect(standard_graph,
                                frames,
                                treat_epsilons_specially=False)
        lattice2 = k2.intersect(modified_graph,
                                frames,
                                treat_epsilons_specially=False)
        lattice1 = k2.connect(lattice1)
        lattice2 = k2.connect(lattice2)
        assert lattice1.num_arcs == 0

        # Since there are two 2s in the input and there are also two 2s
        # in the transcript, the final output contains only one path.
        # If there were more than two 2s in the input, the output
        # would contain more than one path
        aux_labels2 = lattice2.aux_labels[lattice2.aux_labels != 0]
        aux_labels2 = aux_labels2[:-1]  # remove -1
        # Check the labels of the modified lattice computed just above;
        # `aux_labels1` still holds the result of the first scenario.
        assert torch.all(torch.eq(aux_labels2, expected))
Exemple #3
0
    def test_compose(self):
        '''Compose two small transducers and verify the total score in
        the log semiring as well as the gradients that flow back to the
        arc scores of both inputs.
        '''
        a_text = '''
            0 1 11 1 1.0
            0 2 12 2 2.5
            1 3 -1 -1 0
            2 3 -1 -1 2.5
            3
        '''
        b_text = '''
            0 1 1 1 1.0
            0 2 2 3 3.0
            1 2 3 2 2.5
            2 3 -1 -1 2.0
            3
        '''
        a_fsa = k2.Fsa.from_str(a_text, num_aux_labels=1)
        a_fsa = a_fsa.requires_grad_(True)
        b_fsa = k2.Fsa.from_str(b_text, num_aux_labels=1)
        b_fsa = b_fsa.requires_grad_(True)

        composed = k2.connect(k2.compose(a_fsa, b_fsa, inner_labels='inner'))
        # get_tot_scores is invoked on an FsaVec holding the single result.
        ans = k2.create_fsa_vec([composed])

        scores = ans.get_tot_scores(log_semiring=True, use_double_scores=False)
        # The reference values for `scores`, `a_fsa.grad` and `b_fsa.grad`
        # are computed using GTN.
        # See https://bit.ly/3heLAJq
        assert scores.item() == 10
        scores.backward()

        expected_grad = torch.tensor([0., 1., 0., 1.])
        assert torch.allclose(a_fsa.grad, expected_grad)
        assert torch.allclose(b_fsa.grad, expected_grad)
Exemple #4
0
    def test_ragged_aux_labels(self):
        '''Compose an FSA whose aux_labels are ragged with an FSA whose
        aux_labels are a plain tensor, and check the labels, the inner
        labels (``phones``) and the ragged aux_labels of the result.
        '''
        acceptor_text = '''
            0 1 1 0.1
            0 2 5 0.6
            1 2 3 0.3
            2 3 3 0.5
            2 4 2 0.6
            3 5 -1 0.7
            4 5 -1 0.8
            5
        '''

        transducer_text = '''
            0 0 2 1 1
            0 1 4 3 2
            0 1 6 2 2
            0 2 -1 -1 0
            1 1 2 5 3
            1 2 -1 -1 4
            2
        '''
        # https://git.io/JqNok
        fsa1 = k2.Fsa.from_str(acceptor_text)
        # One ragged row of aux_labels per arc of fsa1.
        fsa1.aux_labels = k2.RaggedInt('[[2] [2 4] [5] [3] [2] [-1] [-1]]')

        # https://git.io/JqNaJ
        fsa2 = k2.Fsa.from_str(transducer_text, num_aux_labels=1)

        # https://git.io/JqNon
        composed = k2.compose(fsa1, fsa2, inner_labels='phones')
        ans = k2.connect(composed)

        expected_labels = torch.tensor([5, 0, 2, -1])
        expected_phones = torch.tensor([2, 4, 2, -1])
        expected_aux = k2.RaggedInt('[[1] [3] [5] [-1]]')

        assert torch.all(torch.eq(ans.labels, expected_labels))
        assert torch.all(torch.eq(ans.phones, expected_phones))
        # Ragged values are compared through their string form.
        assert str(ans.aux_labels) == str(expected_aux)
Exemple #5
0
    def test_compose_inner_labels(self):
        '''Compose two transducers while keeping the matched inner labels
        under the attribute ``phones``, then check labels, inner labels
        and aux_labels of the connected result.
        '''
        left_text = '''
            0 1 1 2 0.1
            0 2 0 2 0.2
            1 3 3 5 0.3
            2 3 5 4 0.4
            3 4 3 3 0.5
            3 5 2 2 0.6
            4 6 -1 -1 0.7
            5 6 -1 -1 0.8
            6
        '''

        right_text = '''
            0 0 2 1 1
            0 1 4 3 2
            0 1 6 2 2
            0 2 -1 -1 0
            1 1 2 5 3
            1 2 -1 -1 4
            2
        '''

        # https://git.io/JqN2j
        fsa1 = k2.Fsa.from_str(left_text, num_aux_labels=1)

        # https://git.io/JqNaJ
        fsa2 = k2.Fsa.from_str(right_text, num_aux_labels=1)

        # https://git.io/JqNaT
        composed = k2.compose(fsa1, fsa2, inner_labels='phones')
        ans = k2.connect(composed)

        expected = {
            'labels': torch.tensor([0, 5, 2, -1]),
            'phones': torch.tensor([2, 4, 2, -1]),
            'aux_labels': torch.tensor([1, 3, 5, -1]),
        }
        assert torch.all(torch.eq(ans.labels, expected['labels']))
        assert torch.all(torch.eq(ans.phones, expected['phones']))
        assert torch.all(torch.eq(ans.aux_labels, expected['aux_labels']))
Exemple #6
0
    def compile(self, texts: Iterable[str],
                P: k2.Fsa) -> Tuple[k2.Fsa, k2.Fsa]:
        '''Build the numerator and denominator graphs for a batch of
        transcripts given the bigram phone LM.

        Args:
          texts:
            The transcripts, one string per utterance. Words within a
            transcript are separated by spaces. Both iteration and
            `len()` are applied to it.
          P:
            The bigram phone LM created by :func:`create_bigram_phone_lm`.
            It has to live on the CPU.
        Returns:
          A tuple (num_graph, den_graph), where

            - `num_graph` is the numerator graph, an FsaVec with one
              entry per transcript, i.e. with shape
              `(len(texts), None, None)`.

            - `den_graph` is the denominator graph, an FsaVec holding
              `len(texts)` references to the same detached graph.
        '''
        assert P.is_cpu()

        topo_and_lm = k2.intersect(self.ctc_topo, P)
        ctc_topo_P = k2.connect(topo_and_lm.invert_())

        per_text_graphs = []
        for text in texts:
            per_text_graphs.append(self.compile_one_and_cache(text))
        num_graphs = k2.create_fsa_vec(per_text_graphs)

        num = k2.arc_sort(k2.connect(k2.compose(ctc_topo_P, num_graphs)))

        # detach() is called once; the same object is repeated per text.
        den_fsa = ctc_topo_P.detach()
        den = k2.create_fsa_vec([den_fsa] * len(texts))

        return num, den
Exemple #7
0
    def compile(self,
                texts: Iterable[str],
                P: k2.Fsa,
                replicate_den: bool = True) -> Tuple[k2.Fsa, k2.Fsa]:
        '''Build the numerator and denominator graphs for a batch of
        transcripts given the bigram phone LM.

        Args:
          texts:
            The transcripts, one string per utterance. Words within a
            transcript are separated by spaces.
          P:
            The bigram phone LM created by :func:`create_bigram_phone_lm`.
            It must be on the same device as this compiler.
          replicate_den:
            If True, the single denominator FSA is replicated
            `len(texts)` times so that it matches the numerator;
            if False, the denominator FsaVec holds just that one FSA.
        Returns:
          A tuple (num_graph, den_graph), where

            - `num_graph` is the numerator graph. It is an FsaVec with
              shape `(len(texts), None, None)`.

            - `den_graph` is the denominator graph. It is an FsaVec with
              the same shape as `num_graph` if `replicate_den` is True;
              otherwise it is an FsaVec containing a single FSA.
        '''
        assert P.device == self.device

        lm_with_loops = k2.add_epsilon_self_loops(P)
        topo_lm_inv = k2.intersect(self.ctc_topo_inv,
                                   lm_with_loops,
                                   treat_epsilons_specially=False)
        ctc_topo_P = k2.arc_sort(topo_lm_inv.invert())

        num_graphs = self.build_num_graphs(texts)
        num_with_loops = k2.arc_sort(
            k2.remove_epsilon_and_add_self_loops(num_graphs))

        num = k2.arc_sort(
            k2.compose(ctc_topo_P,
                       num_with_loops,
                       treat_epsilons_specially=False))

        single_den = k2.create_fsa_vec([ctc_topo_P.detach()])
        if not replicate_den:
            return num, single_den

        # Index 0 repeated once per transcript selects the same FSA.
        indexes = torch.zeros(len(texts),
                              dtype=torch.int32,
                              device=self.device)
        return num, k2.index_fsa(single_den, indexes)
Exemple #8
0
 def compile_one_and_cache(self, text: str) -> k2.Fsa:
     '''Build the graph for a single transcript.

     Words that are not in `self.words` are replaced by `self.oov`
     before being mapped to IDs. The linear FSA of word IDs is
     intersected with `self.L_inv`, inverted, arc sorted and finally
     composed with `self.ctc_topo`.

     NOTE(review): no caching happens inside this body; presumably it
     is applied by a decorator outside this view -- confirm.

     Args:
       text:
         A transcript whose words are separated by single spaces.
     Returns:
       The connected composition of `self.ctc_topo` with the graph of
       the transcript.
     '''
     word_ids = [
         self.words[word if word in self.words else self.oov]
         for word in text.split(' ')
     ]
     transcript = k2.linear_fsa(word_ids)
     with_lexicon = k2.connect(k2.intersect(transcript, self.L_inv))
     graph = k2.arc_sort(with_lexicon.invert_())
     return k2.connect(k2.compose(self.ctc_topo, graph))
Exemple #9
0
    def test_case1(self):
        '''Compare `k2.ctc_loss` with PyTorch's `ctc_loss` (values and
        gradients) for one single-frame sequence with uniform
        activations and the one-symbol target `a`.
        '''
        for device in self.devices:
            # There are five symbols in total: <blk>, a, b, c, d
            uniform = torch.tensor([0.2, 0.2, 0.2, 0.2, 0.2]).to(device)

            # (T, N, C)
            torch_activation = uniform.reshape(1, 1,
                                               -1).requires_grad_(True)
            # (N, T, C), an independent copy of the same values
            k2_activation = uniform.detach().clone().reshape(
                1, 1, -1).requires_grad_(True)

            # ---- reference: PyTorch ----
            # we have only one sequence and its label is `a`
            targets = torch.tensor([1]).to(device)
            input_lengths = torch.tensor([1]).to(device)
            target_lengths = torch.tensor([1]).to(device)
            torch_loss = torch.nn.functional.ctc_loss(
                log_probs=torch.nn.functional.log_softmax(
                    torch_activation, dim=-1),  # (T, N, C)
                targets=targets,
                input_lengths=input_lengths,
                target_lengths=target_lengths,
                reduction='mean')

            expected = torch.tensor([1.6094379425049]).to(device)
            assert torch.allclose(torch_loss, expected)

            # ---- k2 ----
            # (N, T, C)
            k2_log_probs = torch.nn.functional.log_softmax(k2_activation,
                                                           dim=-1)
            supervision_segments = torch.tensor([[0, 0, 1]], dtype=torch.int32)
            dense_fsa_vec = k2.DenseFsaVec(k2_log_probs,
                                           supervision_segments).to(device)

            decoding_graph = k2.compose(k2.ctc_topo(4),
                                        k2.linear_fsa([1])).to(device)

            k2_loss = k2.ctc_loss(decoding_graph,
                                  dense_fsa_vec,
                                  reduction='mean',
                                  target_lengths=target_lengths)

            assert torch.allclose(torch_loss, k2_loss)

            # Gradients w.r.t. the activations have to agree as well.
            torch_loss.backward()
            k2_loss.backward()
            assert torch.allclose(torch_activation.grad, k2_activation.grad)
Exemple #10
0
    def test_case3(self):
        '''Compare `k2.ctc_loss` with PyTorch's `ctc_loss` (values and
        gradients) for one sequence of three frames whose labels are
        `b,c`.
        '''
        for device in self.devices:
            frames = torch.tensor([[
                [-5, -4, -3, -2, -1],
                [-10, -9, -8, -7, -6],
                [-15, -14, -13, -12, -11.],
            ]])
            # (T, N, C)
            torch_activation = frames.permute(
                1, 0, 2).to(device).requires_grad_(True)
            torch_activation = torch_activation.to(torch.float32)
            torch_activation.requires_grad_(True)

            k2_activation = torch_activation.detach().clone()
            k2_activation = k2_activation.requires_grad_(True)

            # ---- reference: PyTorch ----
            # we have only one sequence and its labels are `b,c`
            targets = torch.tensor([2, 3]).to(device)
            input_lengths = torch.tensor([3]).to(device)
            target_lengths = torch.tensor([2]).to(device)

            torch_loss = torch.nn.functional.ctc_loss(
                log_probs=torch.nn.functional.log_softmax(
                    torch_activation, dim=-1),  # (T, N, C)
                targets=targets,
                input_lengths=input_lengths,
                target_lengths=target_lengths,
                reduction='mean')

            # ---- k2 ----
            # (T, N, C) -> (N, T, C)
            k2_log_probs = torch.nn.functional.log_softmax(
                k2_activation.permute(1, 0, 2), dim=-1)

            supervision_segments = torch.tensor([[0, 0, 3]], dtype=torch.int32)
            dense_fsa_vec = k2.DenseFsaVec(k2_log_probs,
                                           supervision_segments).to(device)

            decoding_graph = k2.compose(k2.ctc_topo(4),
                                        k2.linear_fsa([2, 3])).to(device)

            k2_loss = k2.ctc_loss(decoding_graph,
                                  dense_fsa_vec,
                                  reduction='mean',
                                  target_lengths=target_lengths)

            # The reference constant is divided by the target length
            # before being compared with the 'mean' reduced loss.
            unnormalized = torch.tensor([4.938850402832], device=device)
            expected_loss = unnormalized / target_lengths

            assert torch.allclose(torch_loss, k2_loss)
            assert torch.allclose(torch_loss, expected_loss)

            torch_loss.backward()
            k2_loss.backward()
            assert torch.allclose(torch_activation.grad, k2_activation.grad)
Exemple #11
0
    def compile(self, texts: Iterable[str],
                P: k2.Fsa) -> Tuple[k2.Fsa, k2.Fsa, k2.Fsa]:
        '''Build the numerator and denominator graphs for a batch of
        transcripts given the bigram phone LM, and hand back the stored
        decoding graph along with them.

        Args:
          texts:
            The transcripts, one string per utterance. Words within a
            transcript are separated by spaces.
          P:
            The bigram phone LM created by :func:`create_bigram_phone_lm`.
            It must be on the same device as this compiler.
        Returns:
          A tuple (num_graph, den_graph, decoding_graph), where

            - `num_graph` is the numerator graph. It is an FsaVec with
              shape `(len(texts), None, None)`.
              It is the result of compose(ctc_topo, P, L, transcript);
              the inner labels of the last composition are kept in the
              attribute `phones`.

            - `den_graph` is the denominator graph. It is an FsaVec with
              the same shape as `num_graph`.
              It is the result of compose(ctc_topo, P).

            - decoding_graph: It is the result of
              compose(ctc_topo, L_disambig, G).
              Note that it is a single Fsa, not an FsaVec; it is
              returned as stored in `self.decoding_graph`.
        '''
        assert P.device == self.device

        lm_with_loops = k2.add_epsilon_self_loops(P)
        topo_lm_inv = k2.intersect(self.ctc_topo_inv,
                                   lm_with_loops,
                                   treat_epsilons_specially=False)
        ctc_topo_P = k2.arc_sort(topo_lm_inv.invert())

        num_graphs = self.build_num_graphs(texts)
        num_with_loops = k2.arc_sort(
            k2.remove_epsilon_and_add_self_loops(num_graphs))

        num = k2.arc_sort(
            k2.compose(ctc_topo_P,
                       num_with_loops,
                       treat_epsilons_specially=False,
                       inner_labels='phones'))

        # Index 0 repeated once per transcript selects the same FSA.
        single_den = k2.create_fsa_vec([ctc_topo_P.detach()])
        indexes = torch.zeros(len(texts),
                              dtype=torch.int32,
                              device=self.device)
        den = k2.index_fsa(single_den, indexes)

        return num, den, self.decoding_graph
Exemple #12
0
    def test_random_case1(self):
        '''Compare `k2.ctc_loss` with PyTorch's `ctc_loss` (values and
        gradients) on one randomly generated sequence.

        The activations contain 10 more frames than the sequence length
        that is handed to both loss functions.
        '''
        # 1 sequence
        for device in self.devices:
            num_frames = torch.randint(10, 100, (1, )).item()
            num_classes = torch.randint(20, 30, (1, )).item()
            # [N, T, C]
            torch_activation = torch.rand(
                (1, num_frames + 10, num_classes),
                dtype=torch.float32,
                device=device).requires_grad_(True)

            k2_activation = torch_activation.detach().clone()
            k2_activation = k2_activation.requires_grad_(True)

            # [N, T, C] -> [T, N, C]
            time_major = torch_activation.permute(1, 0, 2)
            torch_log_probs = torch.nn.functional.log_softmax(time_major,
                                                              dim=-1)

            input_lengths = torch.tensor([num_frames]).to(device)
            target_lengths = torch.randint(1, num_frames, (1, )).to(device)
            targets = torch.randint(1, num_classes - 1,
                                    (target_lengths.item(), )).to(device)

            torch_loss = torch.nn.functional.ctc_loss(
                log_probs=torch_log_probs,
                targets=targets,
                input_lengths=input_lengths,
                target_lengths=target_lengths,
                reduction='mean')

            k2_log_probs = torch.nn.functional.log_softmax(k2_activation,
                                                           dim=-1)
            supervision_segments = torch.tensor([[0, 0, num_frames]],
                                                dtype=torch.int32)
            dense_fsa_vec = k2.DenseFsaVec(k2_log_probs,
                                           supervision_segments).to(device)

            topo = k2.ctc_topo(num_classes - 1)
            transcript = k2.linear_fsa([targets.tolist()])
            decoding_graph = k2.compose(topo, transcript).to(device)

            k2_loss = k2.ctc_loss(decoding_graph,
                                  dense_fsa_vec,
                                  reduction='mean',
                                  target_lengths=target_lengths)

            assert torch.allclose(torch_loss, k2_loss)

            # Back-propagate through a random scale so that the incoming
            # gradient is not simply 1.
            scale = torch.rand_like(torch_loss) * 100
            (torch_loss * scale).sum().backward()
            (k2_loss * scale).sum().backward()
            assert torch.allclose(torch_activation.grad,
                                  k2_activation.grad,
                                  atol=1e-2)
Exemple #13
0
    def test_compose(self):
        '''Compose two small transducers and verify the total score in
        the log semiring as well as the gradients that flow back to the
        arc scores of both inputs. The composed FSA is printed at the end.
        '''
        a_text = '''
            0 1 11 1 1.0
            0 2 12 2 2.5
            1 3 -1 -1 0
            2 3 -1 -1 2.5
            3
        '''
        b_text = '''
            0 1 1 1 1.0
            0 2 2 3 3.0
            1 2 3 2 2.5
            2 3 -1 -1 2.0
            3
        '''
        a_fsa = k2.Fsa.from_str(a_text).requires_grad_(True)
        b_fsa = k2.Fsa.from_str(b_text).requires_grad_(True)

        ans = k2.connect(k2.compose(a_fsa, b_fsa, inner_labels='inner'))

        # Convert a single FSA to a FsaVec.
        # It will retain `requires_grad_` of `ans`.
        fsa_vec_arcs = _k2.create_fsa_vec([ans.arcs])
        ans.__dict__['arcs'] = fsa_vec_arcs

        scores = k2.get_tot_scores(ans,
                                   log_semiring=True,
                                   use_double_scores=False)
        # The reference values for `scores`, `a_fsa.grad` and `b_fsa.grad`
        # are computed using GTN.
        # See https://bit.ly/3heLAJq
        assert scores.item() == 10
        scores.backward()

        expected_grad = torch.tensor([0., 1., 0., 1.])
        assert torch.allclose(a_fsa.grad, expected_grad)
        assert torch.allclose(b_fsa.grad, expected_grad)
        print(ans)
Exemple #14
0
def compile_HLG(L: Fsa, G: Fsa, H: Fsa, labels_disambig_id_start: int,
                aux_labels_disambig_id_start: int) -> Fsa:
    """
    Creates a decoding graph using a topology ``H``, a lexicon fst ``L``
    and a language model fsa ``G``.
    Involves arc sorting, composition, determinization, removal of
    disambiguation symbols, epsilon removal and a final composition
    with ``H``.

    Args:
        L:
            An ``Fsa`` that represents the lexicon (L), i.e. has phones as ``symbols``
                and words as ``aux_symbols``.
        G:
            An ``Fsa`` that represents the language model (G), i.e. it's an acceptor
            with words as ``symbols``.
        H:  An ``Fsa`` that represents a specific topology used to convert the network
            outputs to a sequence of phones.
            Typically, it's a CTC topology fst, in which when 0 appears on the left
            side, it represents the blank symbol; when it appears on the right side,
            it indicates an epsilon.
        labels_disambig_id_start:
            An integer ID corresponding to the first disambiguation symbol in the
            phonetic alphabet.
        aux_labels_disambig_id_start:
            An integer ID corresponding to the first disambiguation symbol in the
            words vocabulary.
    :return:
        The connected, arc-sorted composition of ``H`` with the processed
        ``LG``. The inner labels of that composition are kept in the
        attribute ``phones`` and a copy of the arc scores is attached
        as ``lm_scores``.
    """
    L = k2.arc_sort(L)
    G = k2.arc_sort(G)
    logging.info("Composing L and G")
    LG = k2.compose(L, G)
    logging.info(f'LG shape = {LG.shape}')
    logging.info("Connecting L*G")
    LG = k2.connect(LG)
    logging.info(f'LG shape = {LG.shape}')
    logging.info("Determinizing L*G")
    LG = k2.determinize(LG)
    logging.info(f'LG shape = {LG.shape}')
    logging.info("Connecting det(L*G)")
    LG = k2.connect(LG)
    logging.info(f'LG shape = {LG.shape}')
    logging.info("Removing disambiguation symbols on L*G")
    # Every ID at or above the first disambiguation ID is overwritten
    # with 0 so that the epsilon removal below can drop it.
    LG.labels[LG.labels >= labels_disambig_id_start] = 0
    if isinstance(LG.aux_labels, torch.Tensor):
        LG.aux_labels[LG.aux_labels >= aux_labels_disambig_id_start] = 0
    else:
        # Non-tensor aux_labels: edit the tensor returned by values().
        LG.aux_labels.values()[
            LG.aux_labels.values() >= aux_labels_disambig_id_start] = 0
    logging.info("Removing epsilons")
    LG = k2.remove_epsilon(LG)
    logging.info(f'LG shape = {LG.shape}')
    logging.info("Connecting rm-eps(det(L*G))")
    LG = k2.connect(LG)
    logging.info(f'LG shape = {LG.shape}')
    LG.aux_labels = k2.ragged.remove_values_eq(LG.aux_labels, 0)

    logging.info("Arc sorting LG")
    LG = k2.arc_sort(LG)

    logging.info("Composing ctc_topo LG")
    HLG = k2.compose(H, LG, inner_labels='phones')

    logging.info("Connecting HLG")
    HLG = k2.connect(HLG)

    logging.info("Arc sorting HLG")
    HLG = k2.arc_sort(HLG)
    logging.info(
        f'HLG is arc sorted: {(HLG.properties & k2.fsa_properties.ARC_SORTED) != 0}'
    )

    # Attach a new attribute `lm_scores` so that we can recover
    # the `am_scores` later.
    # The scores on an arc consists of two parts:
    #  scores = am_scores + lm_scores
    # NOTE: we assume that both kinds of scores are in log-space.
    HLG.lm_scores = HLG.scores.clone()
    return HLG
Exemple #15
0
    def visualize_ctc_topo():
        '''Draw the standard and the modified CTC topologies, the graphs
        obtained by composing them with a transcript, and the lattices
        for two inputs, each into its own SVG file.

        It's for demonstration only, not for testing.
        '''
        max_token = 2
        labels_sym = k2.SymbolTable.from_str('''
            <blk> 0
            z 1
            o 2
        ''')
        aux_labels_sym = k2.SymbolTable.from_str('''
            z 1
            o 2
        ''')

        word_sym = k2.SymbolTable.from_str('''
            zoo 1
        ''')

        def lattice_of(graph, acceptor):
            # Intersection without any special treatment of epsilons.
            return k2.intersect(graph,
                                acceptor,
                                treat_epsilons_specially=False)

        # --- the two topologies ---
        standard = k2.ctc_topo(max_token, modified=False)
        modified = k2.ctc_topo(max_token, modified=True)
        for topo in (standard, modified):
            topo.labels_sym = labels_sym
            topo.aux_labels_sym = aux_labels_sym

        standard.draw('standard_topo.svg', title='standard CTC topo')
        modified.draw('modified_topo.svg', title='modified CTC topo')

        # --- the transcript ---
        fsa = k2.linear_fst([1, 2, 2], [1, 0, 0])
        fsa.labels_sym = labels_sym
        fsa.aux_labels_sym = word_sym
        fsa.draw('transcript.svg', title='transcript')

        # --- topology composed with the transcript ---
        standard_graph = k2.compose(standard, fsa)
        modified_graph = k2.compose(modified, fsa)
        standard_graph.draw('standard_graph.svg', title='standard graph')
        modified_graph.draw('modified_graph.svg', title='modified graph')

        # --- first input ---
        # z z <blk> <blk> o o <blk> o <blk>
        inputs = k2.linear_fsa([1, 1, 0, 0, 2, 2, 0, 2, 0])
        inputs.labels_sym = labels_sym
        inputs.draw('inputs.svg', title='inputs')

        # Note that this lattice is drawn without being connected first.
        standard_lattice = lattice_of(standard_graph, inputs)
        standard_lattice.draw('standard_lattice.svg', title='standard lattice')

        modified_lattice = k2.connect(lattice_of(modified_graph, inputs))
        modified_lattice.draw('modified_lattice.svg', title='modified lattice')

        # --- second input ---
        # z z <blk> <blk> o o o <blk>
        inputs2 = k2.linear_fsa([1, 1, 0, 0, 2, 2, 2, 0])
        inputs2.labels_sym = labels_sym
        inputs2.draw('inputs2.svg', title='inputs2')

        standard_lattice2 = k2.connect(lattice_of(standard_graph, inputs2))
        # It's empty since the topo requires that there must be a blank
        # between the two o's in zoo
        assert standard_lattice2.num_arcs == 0
        standard_lattice2.draw('standard_lattice2.svg',
                               title='standard lattice2')

        modified_lattice2 = k2.connect(lattice_of(modified_graph, inputs2))
        modified_lattice2.draw('modified_lattice2.svg',
                               title='modified lattice2')
Exemple #16
0
def compile_LG(L: Fsa, G: Fsa, ctc_topo_inv: Fsa,
               labels_disambig_id_start: int,
               aux_labels_disambig_id_start: int) -> Fsa:
    """
    Creates a decoding graph using a lexicon fst ``L`` and language model fsa ``G``.
    Involves arc sorting, intersection, determinization, removal of
    disambiguation symbols, epsilon removal and a final composition with
    ``ctc_topo_inv``.

    NOTE(review): ``L.invert_()`` is called on the argument itself, which
    presumably modifies the caller's ``L`` in place -- confirm.

    Args:
        L:
            An ``Fsa`` that represents the lexicon (L), i.e. has phones as ``symbols``
                and words as ``aux_symbols``.
        G:
            An ``Fsa`` that represents the language model (G), i.e. it's an acceptor
            with words as ``symbols``.
        ctc_topo_inv:  Epsilons are in `aux_labels` and `labels` contain phone IDs.
        labels_disambig_id_start:
            An integer ID corresponding to the first disambiguation symbol in the
            phonetic alphabet.
        aux_labels_disambig_id_start:
            An integer ID corresponding to the first disambiguation symbol in the
            words vocabulary.
    :return:
        The connected, arc-sorted composition of ``ctc_topo_inv`` with the
        processed ``LG``.
    """
    L_inv = k2.arc_sort(L.invert_())
    G = k2.arc_sort(G)

    logging.debug("Intersecting L and G")
    LG = k2.intersect(L_inv, G)
    logging.debug(f'LG shape = {LG.shape}')

    logging.debug("Connecting L*G")
    LG = k2.connect(LG).invert_()
    logging.debug(f'LG shape = {LG.shape}')

    logging.debug("Determinizing L*G")
    LG = k2.determinize(LG)
    logging.debug(f'LG shape = {LG.shape}')

    logging.debug("Connecting det(L*G)")
    LG = k2.connect(LG)
    logging.debug(f'LG shape = {LG.shape}')

    logging.debug("Removing disambiguation symbols on L*G")
    # Every ID at or above the first disambiguation ID becomes 0.
    is_label_disambig = LG.labels >= labels_disambig_id_start
    LG.labels[is_label_disambig] = 0
    if not isinstance(LG.aux_labels, torch.Tensor):
        # Non-tensor aux_labels: edit the tensor returned by values().
        LG.aux_labels.values()[
            LG.aux_labels.values() >= aux_labels_disambig_id_start] = 0
    else:
        is_aux_disambig = LG.aux_labels >= aux_labels_disambig_id_start
        LG.aux_labels[is_aux_disambig] = 0

    logging.debug("Removing epsilons")
    LG = k2.remove_epsilons_iterative_tropical(LG)
    logging.debug(f'LG shape = {LG.shape}')

    logging.debug("Connecting rm-eps(det(L*G))")
    LG = k2.connect(LG)
    logging.debug(f'LG shape = {LG.shape}')
    LG.aux_labels = k2.ragged.remove_values_eq(LG.aux_labels, 0)

    logging.debug("Arc sorting")
    LG = k2.arc_sort(LG)

    logging.debug("Composing")
    LG = k2.compose(ctc_topo_inv, LG)

    logging.debug("Connecting")
    LG = k2.connect(LG)

    logging.debug("Arc sorting")
    LG = k2.arc_sort(LG)
    is_arc_sorted = (LG.properties & k2.fsa_properties.ARC_SORTED) != 0
    logging.debug(f'LG is arc sorted: {is_arc_sorted}')
    return LG
Exemple #17
0
    def test_case4(self):
        '''Compare `k2.ctc_loss` with PyTorch's `ctc_loss` (values and
        gradients) on a batch of three sequences with different
        lengths, using `reduction='sum'`.
        '''
        for device in self.devices:
            # put case3, case2 and case1 into a batch
            torch_activation_1 = torch.tensor([[0., 0., 0., 0., 0.]])
            torch_activation_1 = torch_activation_1.to(
                device).requires_grad_(True)

            torch_activation_2 = torch.arange(1, 16).reshape(3, 5)
            torch_activation_2 = torch_activation_2.to(
                torch.float32).to(device).requires_grad_(True)

            torch_activation_3 = torch.tensor([
                [-5, -4, -3, -2, -1],
                [-10, -9, -8, -7, -6],
                [-15, -14, -13, -12, -11.],
            ]).to(device).requires_grad_(True)

            torch_acts = [
                torch_activation_1, torch_activation_2, torch_activation_3
            ]
            # Independent copies with their own gradients for k2.
            k2_acts = [
                act.detach().clone().requires_grad_(True)
                for act in torch_acts
            ]

            # The batch order is case3, case2, case1.
            # [T, N, C]
            torch_activations = torch.nn.utils.rnn.pad_sequence(
                torch_acts[::-1], batch_first=False, padding_value=0)

            # [N, T, C]
            k2_activations = torch.nn.utils.rnn.pad_sequence(
                k2_acts[::-1], batch_first=True, padding_value=0)

            # [[b,c], [c,c], [a]]
            targets = torch.tensor([2, 3, 3, 3, 1]).to(device)
            input_lengths = torch.tensor([3, 3, 1]).to(device)
            target_lengths = torch.tensor([2, 2, 1]).to(device)

            torch_loss = torch.nn.functional.ctc_loss(
                log_probs=torch.nn.functional.log_softmax(
                    torch_activations, dim=-1),  # (T, N, C)
                targets=targets,
                input_lengths=input_lengths,
                target_lengths=target_lengths,
                reduction='sum')

            per_sequence = torch.tensor(
                [4.938850402832, 7.355742931366, 1.6094379425049])
            expected_loss = per_sequence.sum()

            assert torch.allclose(torch_loss, expected_loss.to(device))

            k2_log_probs = torch.nn.functional.log_softmax(k2_activations,
                                                           dim=-1)
            supervision_segments = torch.tensor(
                [[0, 0, 3], [1, 0, 3], [2, 0, 1]], dtype=torch.int32)
            dense_fsa_vec = k2.DenseFsaVec(k2_log_probs,
                                           supervision_segments).to(device)

            # [ [b, c], [c, c], [a]]
            transcripts = k2.linear_fsa([[2, 3], [3, 3], [1]])
            decoding_graph = k2.compose(k2.ctc_topo(4),
                                        transcripts).to(device)

            k2_loss = k2.ctc_loss(decoding_graph,
                                  dense_fsa_vec,
                                  reduction='sum',
                                  target_lengths=target_lengths)

            assert torch.allclose(torch_loss, k2_loss)

            # Back-propagate through a fixed scale vector.
            scale = torch.tensor([1., -2, 3.5]).to(device)
            (torch_loss * scale).sum().backward()
            (k2_loss * scale).sum().backward()
            for torch_act, k2_act in zip(torch_acts, k2_acts):
                assert torch.allclose(torch_act.grad, k2_act.grad)
Exemple #18
0
def compose_with_self_loops(base_graph: 'k2.Fsa', aux_graph: 'k2.Fsa') -> 'k2.Fsa':
    """Compose ``base_graph`` with a self-looped, arc-sorted ``aux_graph``.

    Args:
        base_graph:
            Left operand of the composition. The prepared right operand is
            moved to this graph's device before composing.
        aux_graph:
            Right operand. Epsilon self-loops are added to it and the result
            is arc-sorted before it is used.

    Returns:
        The output of ``k2.compose`` called with
        ``treat_epsilons_specially=False`` and ``inner_labels="phones"``.
    """
    looped = k2.add_epsilon_self_loops(aux_graph)
    looped_sorted = k2.arc_sort(looped)
    rhs = looped_sorted.to(base_graph.device)
    return k2.compose(
        base_graph,
        rhs,
        treat_epsilons_specially=False,
        inner_labels="phones",
    )
Exemple #19
0
def compile_LG(L: Fsa, G: Fsa, ctc_topo: Fsa, labels_disambig_id_start: int,
               aux_labels_disambig_id_start: int) -> Fsa:
    """
    Build a decoding graph from a lexicon ``L``, a language model ``G`` and
    a CTC topology.

    The pipeline, in order: arc-sort ``L`` and ``G``; compose them; connect;
    determinize; connect; zero out disambiguation symbols on both label
    sides; remove epsilons; connect; drop 0 entries from the aux labels;
    arc-sort; compose with ``ctc_topo`` on the left; connect; arc-sort.
    Progress is reported through ``logging.info`` at each stage.

    Args:
        L:
            An ``Fsa`` that represents the lexicon (L), i.e. has phones as
            labels and words as aux labels.
        G:
            An ``Fsa`` that represents the language model (G), i.e. an
            acceptor with words as labels.
        ctc_topo:
            CTC topology fst. When 0 appears on its left side it represents
            the blank symbol; when it appears on the right side it indicates
            an epsilon.
        labels_disambig_id_start:
            Integer ID of the first disambiguation symbol in the phonetic
            alphabet. Every label >= this value is set to 0.
        aux_labels_disambig_id_start:
            Integer ID of the first disambiguation symbol in the words
            vocabulary. Every aux label >= this value is set to 0.

    Returns:
        The final graph, after the last ``k2.arc_sort`` call.
    """
    lexicon = k2.arc_sort(L)
    grammar = k2.arc_sort(G)

    # NOTE: the log text says "Intersecting", but the operation performed
    # here is k2.compose.
    logging.info("Intersecting L and G")
    graph = k2.compose(lexicon, grammar)
    logging.info(f'LG shape = {graph.shape}')

    logging.info("Connecting L*G")
    graph = k2.connect(graph)
    logging.info(f'LG shape = {graph.shape}')

    logging.info("Determinizing L*G")
    graph = k2.determinize(graph)
    logging.info(f'LG shape = {graph.shape}')

    logging.info("Connecting det(L*G)")
    graph = k2.connect(graph)
    logging.info(f'LG shape = {graph.shape}')

    logging.info("Removing disambiguation symbols on L*G")
    # Both sides are edited in place through indexed assignment.
    phone_mask = graph.labels >= labels_disambig_id_start
    graph.labels[phone_mask] = 0
    if isinstance(graph.aux_labels, torch.Tensor):
        word_mask = graph.aux_labels >= aux_labels_disambig_id_start
        graph.aux_labels[word_mask] = 0
    else:
        # Non-tensor aux labels: edit the tensor returned by ``values()``.
        word_mask = graph.aux_labels.values() >= aux_labels_disambig_id_start
        graph.aux_labels.values()[word_mask] = 0

    logging.info("Removing epsilons")
    graph = k2.remove_epsilon(graph)
    logging.info(f'LG shape = {graph.shape}')

    logging.info("Connecting rm-eps(det(L*G))")
    graph = k2.connect(graph)
    logging.info(f'LG shape = {graph.shape}')

    # Strip the 0 entries from the aux labels.
    graph.aux_labels = k2.ragged.remove_values_eq(graph.aux_labels, 0)

    logging.info("Arc sorting LG")
    graph = k2.arc_sort(graph)

    logging.info("Composing ctc_topo LG")
    graph = k2.compose(ctc_topo, graph, inner_labels='phones')

    logging.info("Connecting LG")
    graph = k2.connect(graph)

    logging.info("Arc sorting LG")
    graph = k2.arc_sort(graph)

    is_arc_sorted = (graph.properties & k2.fsa_properties.ARC_SORTED) != 0
    logging.info(f'LG is arc sorted: {is_arc_sorted}')
    return graph
Exemple #20
0
    def test_random_case2(self):
        """Check ``k2.ctc_loss`` against torch's ``ctc_loss`` on 2 sequences.

        For every device in ``self.devices``, two random activation matrices
        of random lengths are created, random targets are drawn, and the
        'mean'-reduced losses of both implementations are compared. The
        gradients w.r.t. the activations are then compared with
        ``atol=1e-2``.
        """
        for device in self.devices:
            # Random frame counts and class count; longer sequence goes first.
            len_a = torch.randint(10, 200, (1, )).item()
            len_b = torch.randint(9, 100, (1, )).item()
            num_classes = torch.randint(20, 30, (1, )).item()
            if len_b > len_a:
                len_a, len_b = len_b, len_a

            ref_act_a = torch.rand((len_a, num_classes),
                                   dtype=torch.float32,
                                   device=device).requires_grad_(True)
            ref_act_b = torch.rand((len_b, num_classes),
                                   dtype=torch.float32,
                                   device=device).requires_grad_(True)

            # Independent leaf copies so each implementation gets its own
            # gradients.
            k2_act_a = ref_act_a.detach().clone().requires_grad_(True)
            k2_act_b = ref_act_b.detach().clone().requires_grad_(True)

            # [T, N, C]
            ref_padded = torch.nn.utils.rnn.pad_sequence(
                [ref_act_a, ref_act_b], batch_first=False, padding_value=0)

            # [N, T, C]
            k2_padded = torch.nn.utils.rnn.pad_sequence(
                [k2_act_a, k2_act_b], batch_first=True, padding_value=0)

            tgt_len_a = torch.randint(1, len_a, (1, )).item()
            tgt_len_b = torch.randint(1, len_b, (1, )).item()
            target_lengths = torch.tensor([tgt_len_a, tgt_len_b]).to(device)

            # Both targets are stored back to back in one flat tensor.
            targets = torch.randint(1, num_classes - 1,
                                    (target_lengths.sum(), )).to(device)

            # [T, N, C]
            ref_log_probs = torch.nn.functional.log_softmax(ref_padded,
                                                            dim=-1)
            input_lengths = torch.tensor([len_a, len_b]).to(device)

            ref_loss = torch.nn.functional.ctc_loss(
                log_probs=ref_log_probs,
                targets=targets,
                input_lengths=input_lengths,
                target_lengths=target_lengths,
                reduction='mean')

            assert len_a >= len_b
            segments = torch.tensor([[0, 0, len_a], [1, 0, len_b]],
                                    dtype=torch.int32)
            k2_log_probs = torch.nn.functional.log_softmax(k2_padded, dim=-1)
            dense_fsa_vec = k2.DenseFsaVec(k2_log_probs, segments).to(device)

            topo = k2.ctc_topo(num_classes - 1)
            transcript_a = targets[:tgt_len_a].tolist()
            transcript_b = targets[tgt_len_a:].tolist()
            transcripts = k2.linear_fsa([transcript_a, transcript_b])
            decoding_graph = k2.compose(topo, transcripts).to(device)

            k2_loss = k2.ctc_loss(decoding_graph,
                                  dense_fsa_vec,
                                  reduction='mean',
                                  target_lengths=target_lengths)

            assert torch.allclose(ref_loss, k2_loss)

            # Back-propagate the same random multiple of each loss.
            scale = torch.rand_like(ref_loss) * 100
            (ref_loss * scale).sum().backward()
            (k2_loss * scale).sum().backward()

            activation_pairs = ((ref_act_a, k2_act_a), (ref_act_b, k2_act_b))
            for ref_act, k2_act in activation_pairs:
                assert torch.allclose(ref_act.grad, k2_act.grad, atol=1e-2)