Example #1
    def test_no_repeated(self):
        # standard ctc topo and modified ctc topo
        # should be equivalent if there are no
        # repeated neighboring symbols in the transcript
        max_token = 3
        standard = k2.ctc_topo(max_token, modified=False)
        modified = k2.ctc_topo(max_token, modified=True)
        transcript = k2.linear_fsa([1, 2, 3])
        standard_graph = k2.compose(standard, transcript)
        modified_graph = k2.compose(modified, transcript)

        input1 = k2.linear_fsa([1, 1, 1, 0, 0, 2, 2, 3, 3])
        input2 = k2.linear_fsa([1, 1, 0, 0, 2, 2, 0, 3, 3])
        inputs = [input1, input2]
        for i in inputs:
            lattice1 = k2.intersect(standard_graph,
                                    i,
                                    treat_epsilons_specially=False)
            lattice2 = k2.intersect(modified_graph,
                                    i,
                                    treat_epsilons_specially=False)
            lattice1 = k2.connect(lattice1)
            lattice2 = k2.connect(lattice2)

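            # aux_label 0 comes from blank/repeat arcs and the final arc has
            # aux_label -1, so both are filtered out below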
            aux_labels1 = lattice1.aux_labels[lattice1.aux_labels != 0]
            aux_labels2 = lattice2.aux_labels[lattice2.aux_labels != 0]
            aux_labels1 = aux_labels1[:-1]  # remove -1
            aux_labels2 = aux_labels2[:-1]
            assert torch.all(torch.eq(aux_labels1, aux_labels2))
            assert torch.all(torch.eq(aux_labels2, torch.tensor([1, 2, 3])))
Example #2
    def test(self):
        # for the symbol table
        # <eps> 0
        # a 1
        # b 2
        # c 3

        # an FSA that recognizes a+(b|c)
        s = '''
            0 1 1 0.1
            1 1 1 0.2
            1 2 2 0.3
            1 3 3 0.4
            2 4 -1 0.5
            3 4 -1 0.6
            4
        '''
        a_fsa = k2.Fsa.from_str(s)
        a_fsa.requires_grad_(True)

        # an FSA that recognizes ab
        s = '''
            0 1 1 10
            1 2 2 20
            2 3 -1 30
            3
        '''
        b_fsa = k2.Fsa.from_str(s)
        b_fsa.requires_grad_(True)

        fsa = k2.intersect(a_fsa, b_fsa)
        assert len(fsa.shape) == 2
        actual_str = k2.to_str(fsa)
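        # each arc score in the intersection is the sum of the scores of the
        # two matched arcs, e.g. 0.1 + 10 = 10.1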
        expected_str = '\n'.join(
            ['0 1 1 10.1', '1 2 2 20.3', '2 3 -1 30.5', '3'])
        assert actual_str.strip() == expected_str

        loss = fsa.scores.sum()
        loss.backward()
        # arcs 0, 2, and 4 of a_fsa are kept in the final intersected FSA
        assert torch.allclose(
            a_fsa.scores.grad,
            torch.tensor([1, 0, 1, 0, 1, 0], dtype=torch.float32))

        assert torch.allclose(b_fsa.scores.grad,
                              torch.tensor([1, 1, 1], dtype=torch.float32))

        # if any of the input FSAs is an FsaVec,
        # the output FSA is also an FsaVec.
        a_fsa.scores.grad = None
        b_fsa.scores.grad = None
        a_fsa = k2.create_fsa_vec([a_fsa])
        fsa = k2.intersect(a_fsa, b_fsa)
        assert len(fsa.shape) == 3
Example #3
    def test_treat_epsilon_specially_false(self):
        devices = [torch.device('cpu')]
        if torch.cuda.is_available() and k2.with_cuda:
            devices.append(torch.device('cuda'))

        for device in devices:
            # a_fsa recognizes `(0|1)2*`
            s1 = '''
                0 1 0 0.1
                0 1 1 0.2
                1 1 2 0.3
                1 2 -1 0.4
                2
            '''
            a_fsa = k2.Fsa.from_str(s1).to(device)
            a_fsa.requires_grad_(True)

            # b_fsa recognizes `1|2`
            s2 = '''
                0 1 1 1
                0 1 2 2
                1 2 -1 3
                2
            '''
            b_fsa = k2.Fsa.from_str(s2).to(device)
            b_fsa.requires_grad_(True)

            # fsa recognizes `1`
            fsa = k2.intersect(a_fsa, b_fsa, treat_epsilons_specially=False)
            assert len(fsa.shape) == 2
            actual_str = k2.to_str_simple(fsa)
            expected_str = '\n'.join(['0 1 1 1.2', '1 2 -1 3.4', '2'])
            assert actual_str.strip() == expected_str

            loss = fsa.scores.sum()
            (-loss).backward()
            # arcs 1 and 3 of a_fsa are kept in the final intersected FSA
            assert torch.allclose(a_fsa.grad,
                                  torch.tensor([0, -1, 0, -1]).to(a_fsa.grad))

            # arcs 0 and 2 of b_fsa are kept in the final intersected FSA
            assert torch.allclose(b_fsa.grad,
                                  torch.tensor([-1, 0, -1]).to(b_fsa.grad))

            # if any of the input FSAs is an FsaVec,
            # the output FSA is also an FsaVec.
            a_fsa.scores.grad = None
            b_fsa.scores.grad = None
            a_fsa = k2.create_fsa_vec([a_fsa])
            fsa = k2.intersect(a_fsa, b_fsa, treat_epsilons_specially=False)
            assert len(fsa.shape) == 3
Example #4
    def test_treat_epsilon_specially_true(self):
        # this version works only on CPU and requires
        # arc-sorted inputs
        # a_fsa recognizes `(1|3)?2*`
        s1 = '''
            0 1 3 0.0
            0 1 1 0.2
            0 1 0 0.1
            1 1 2 0.3
            1 2 -1 0.4
            2
        '''
        a_fsa = k2.Fsa.from_str(s1)
        a_fsa.requires_grad_(True)

        # b_fsa recognizes `1|2|5`
        s2 = '''
            0 1 5 0
            0 1 1 1
            0 1 2 2
            1 2 -1 3
            2
        '''
        b_fsa = k2.Fsa.from_str(s2)
        b_fsa.requires_grad_(True)

        # fsa recognizes 1|2
        fsa = k2.intersect(k2.arc_sort(a_fsa), k2.arc_sort(b_fsa))
        assert len(fsa.shape) == 2
        actual_str = k2.to_str_simple(fsa)
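        # with treat_epsilons_specially=True (the default), the epsilon arc of
        # a_fsa is kept on its own with score 0.1, while matched arcs sum their
        # scores, e.g. 0.2 + 1 = 1.2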
        expected_str = '\n'.join(
            ['0 1 0 0.1', '0 2 1 1.2', '1 2 2 2.3', '2 3 -1 3.4', '3'])
        assert actual_str.strip() == expected_str

        loss = fsa.scores.sum()
        (-loss).backward()
        # arcs 1, 2, 3, and 4 of a_fsa are kept in the final intersected FSA
        assert torch.allclose(a_fsa.grad,
                              torch.tensor([0, -1, -1, -1, -1]).to(a_fsa.grad))

        # arcs 1, 2, and 3 of b_fsa are kept in the final intersected FSA
        assert torch.allclose(b_fsa.grad,
                              torch.tensor([0, -1, -1, -1]).to(b_fsa.grad))

        # if any of the input FSAs is an FsaVec,
        # the output FSA is also an FsaVec.
        a_fsa.scores.grad = None
        b_fsa.scores.grad = None
        a_fsa = k2.create_fsa_vec([a_fsa])
        fsa = k2.intersect(k2.arc_sort(a_fsa), k2.arc_sort(b_fsa))
        assert len(fsa.shape) == 3
Example #5
    def test_with_repeated(self):
        max_token = 2
        standard = k2.ctc_topo(max_token, modified=False)
        modified = k2.ctc_topo(max_token, modified=True)
        transcript = k2.linear_fsa([1, 2, 2])
        standard_graph = k2.compose(standard, transcript)
        modified_graph = k2.compose(modified, transcript)

        # There is a blank separating the two 2s in the input,
        # so the standard and modified CTC topos should be equivalent
        input = k2.linear_fsa([1, 1, 2, 2, 0, 2, 2, 0, 0])
        lattice1 = k2.intersect(standard_graph,
                                input,
                                treat_epsilons_specially=False)
        lattice2 = k2.intersect(modified_graph,
                                input,
                                treat_epsilons_specially=False)
        lattice1 = k2.connect(lattice1)
        lattice2 = k2.connect(lattice2)

        aux_labels1 = lattice1.aux_labels[lattice1.aux_labels != 0]
        aux_labels2 = lattice2.aux_labels[lattice2.aux_labels != 0]
        aux_labels1 = aux_labels1[:-1]  # remove -1
        aux_labels2 = aux_labels2[:-1]
        assert torch.all(torch.eq(aux_labels1, aux_labels2))
        assert torch.all(torch.eq(aux_labels1, torch.tensor([1, 2, 2])))

        # There are no blanks separating the two 2s in the input.
        # The standard CTC topo requires a blank between repeated symbols,
        # so lattice1 below is empty
        input = k2.linear_fsa([1, 1, 2, 2, 0, 0])
        lattice1 = k2.intersect(standard_graph,
                                input,
                                treat_epsilons_specially=False)
        lattice2 = k2.intersect(modified_graph,
                                input,
                                treat_epsilons_specially=False)
        lattice1 = k2.connect(lattice1)
        lattice2 = k2.connect(lattice2)
        assert lattice1.num_arcs == 0

        # Since there are two 2s in the input and there are also two 2s
        # in the transcript, the final output contains only one path.
        # If there were more than two 2s in the input, the output
        # would contain more than one path
        aux_labels2 = lattice2.aux_labels[lattice2.aux_labels != 0]
        aux_labels2 = aux_labels2[:-1]
        assert torch.all(torch.eq(aux_labels2, torch.tensor([1, 2, 2])))
Example #6
    def compile(self, texts: Iterable[str],
                P: k2.Fsa) -> Tuple[k2.Fsa, k2.Fsa]:
        '''Create numerator and denominator graphs from transcripts
        and the bigram phone LM.

        Args:
          texts:
            A list of transcripts. Within a transcript, words are
            separated by spaces.
          P:
            The bigram phone LM created by :func:`create_bigram_phone_lm`.
        Returns:
          A tuple (num_graph, den_graph), where

            - `num_graph` is the numerator graph. It is an FsaVec with
              shape `(len(texts), None, None)`.

            - `den_graph` is the denominator graph. It is an FsaVec with the
              same shape as `num_graph`.
        '''
        assert P.is_cpu()

        ctc_topo_P = k2.intersect(self.ctc_topo, P).invert_()
        ctc_topo_P = k2.connect(ctc_topo_P)

        num_graphs = k2.create_fsa_vec(
            [self.compile_one_and_cache(text) for text in texts])

        num = k2.compose(ctc_topo_P, num_graphs)
        num = k2.connect(num)
        num = k2.arc_sort(num)

        den = k2.create_fsa_vec([ctc_topo_P.detach()] * len(texts))

        return num, den
Example #7
    def test_case1(self):
        devices = [torch.device('cpu')]
        if torch.cuda.is_available():
            devices.append(torch.device('cuda'))

        for device in devices:
            # suppose we have five symbols: <blk>, a, b, c, d
            torch_activation = torch.tensor([0.2, 0.2, 0.2, 0.2,
                                             0.2]).to(device)
            k2_activation = torch_activation.detach().clone()

            # (T, N, C)
            torch_activation = torch_activation.reshape(
                1, 1, -1).requires_grad_(True)

            # (N, T, C)
            k2_activation = k2_activation.reshape(1, 1,
                                                  -1).requires_grad_(True)

            torch_log_probs = torch.nn.functional.log_softmax(
                torch_activation, dim=-1)  # (T, N, C)

            # we have only one sequence and its label is `a`
            targets = torch.tensor([1]).to(device)
            input_lengths = torch.tensor([1]).to(device)
            target_lengths = torch.tensor([1]).to(device)
            torch_loss = torch.nn.functional.ctc_loss(
                log_probs=torch_log_probs,
                targets=targets,
                input_lengths=input_lengths,
                target_lengths=target_lengths,
                reduction='none')

            assert torch.allclose(torch_loss,
                                  torch.tensor([1.6094379425049]).to(device))

            # (N, T, C)
            k2_log_probs = torch.nn.functional.log_softmax(k2_activation,
                                                           dim=-1)

            supervision_segments = torch.tensor([[0, 0, 1]], dtype=torch.int32)
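            # each row of supervision_segments is
            # (sequence_index, start_frame, num_frames)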
            dense_fsa_vec = k2.DenseFsaVec(k2_log_probs,
                                           supervision_segments).to(device)

            ctc_topo_inv = k2.arc_sort(
                build_ctc_topo([0, 1, 2, 3, 4]).invert_())
            linear_fsa = k2.linear_fsa([1])
            decoding_graph = k2.intersect(ctc_topo_inv, linear_fsa)
            decoding_graph = k2.connect(decoding_graph).invert_().to(device)

            target_graph = k2.intersect_dense(decoding_graph, dense_fsa_vec,
                                              100.0)

            k2_scores = target_graph.get_tot_scores(log_semiring=True,
                                                    use_double_scores=False)
            assert torch.allclose(torch_loss, -1 * k2_scores)

            torch_loss.backward()
            (-k2_scores).backward()
            assert torch.allclose(torch_activation.grad, k2_activation.grad)
Example #8
    def build_num_graphs(self, texts: List[str]) -> k2.Fsa:
        '''Convert transcripts to an FsaVec with the help of the lexicon
        and word symbol table.

        Args:
          texts:
            Each element is a transcript containing words separated by spaces.
            For instance, it may be 'HELLO SNOWFALL', which contains
            two words.

        Returns:
          Return an FST (FsaVec) corresponding to the transcripts. Its `labels` are
          phone IDs and `aux_labels` are word IDs.
        '''
        word_ids_list = []
        for text in texts:
            word_ids = []
            for word in text.split(' '):
                if word in self.words:
                    word_ids.append(self.words[word])
                else:
                    word_ids.append(self.oov_id)
            word_ids_list.append(word_ids)

        fsa = k2.linear_fsa(word_ids_list, self.device)
        fsa = k2.add_epsilon_self_loops(fsa)
        num_graphs = k2.intersect(self.L_inv,
                                  fsa,
                                  treat_epsilons_specially=False).invert_()
        num_graphs = k2.arc_sort(num_graphs)
        return num_graphs
Example #9
def get_hierarchical_targets(ys: List[List[int]],
                             lexicon: k2.Fsa) -> List[Tensor]:
    """Get hierarchical transcripts (i.e., phone level transcripts) from transcripts (i.e., word level transcripts).

    Args:
        ys: Word level transcripts.
        lexicon: Its labels are words, while its aux_labels are phones.

    Returns:
        List[Tensor]: Phone level transcripts.

    """

    if lexicon is None:
        return ys
    else:
        L_inv = lexicon

    n_batch = len(ys)
    indices = torch.tensor(range(n_batch))

    transcripts = k2.create_fsa_vec([k2.linear_fsa(x) for x in ys])
    transcripts_lexicon = k2.intersect(transcripts, L_inv)
    transcripts_lexicon = k2.arc_sort(k2.connect(transcripts_lexicon))
    transcripts_lexicon = k2.remove_epsilon(transcripts_lexicon)
    transcripts_lexicon = k2.shortest_path(transcripts_lexicon,
                                           use_double_scores=True)

    ys = get_texts(transcripts_lexicon, indices)
    ys = [torch.tensor(y) for y in ys]

    return ys
Example #10
    def compile_one_and_cache(self, text: str) -> Fsa:
        tokens = (token if token in self.vocab._sym2id else self.oov
                  for token in text.split(' '))
        word_ids = [self.vocab.get(token) for token in tokens]
        fsa = k2.linear_fsa(word_ids)
        decoding_graph = k2.connect(k2.intersect(fsa, self.L_inv)).invert_()
        decoding_graph = k2.add_epsilon_self_loops(decoding_graph)
        return decoding_graph
Example #11
    def compile(self,
                texts: Iterable[str],
                P: k2.Fsa,
                replicate_den: bool = True) -> Tuple[k2.Fsa, k2.Fsa]:
        '''Create numerator and denominator graphs from transcripts
        and the bigram phone LM.

        Args:
          texts:
            A list of transcripts. Within a transcript, words are
            separated by spaces.
          P:
            The bigram phone LM created by :func:`create_bigram_phone_lm`.
          replicate_den:
            If True, the returned den_graph is replicated to match the number
            of FSAs in the returned num_graph; if False, the returned den_graph
            contains only a single FSA.
        Returns:
          A tuple (num_graph, den_graph), where

            - `num_graph` is the numerator graph. It is an FsaVec with
              shape `(len(texts), None, None)`.

            - `den_graph` is the denominator graph. It is an FsaVec with the
              same shape as `num_graph` if replicate_den is True; otherwise, it
              is an FsaVec containing only a single FSA.
        '''
        assert P.device == self.device
        P_with_self_loops = k2.add_epsilon_self_loops(P)

        ctc_topo_P = k2.intersect(self.ctc_topo_inv,
                                  P_with_self_loops,
                                  treat_epsilons_specially=False).invert()

        ctc_topo_P = k2.arc_sort(ctc_topo_P)

        num_graphs = self.build_num_graphs(texts)
        num_graphs_with_self_loops = k2.remove_epsilon_and_add_self_loops(
            num_graphs)

        num_graphs_with_self_loops = k2.arc_sort(num_graphs_with_self_loops)

        num = k2.compose(ctc_topo_P,
                         num_graphs_with_self_loops,
                         treat_epsilons_specially=False)
        num = k2.arc_sort(num)

        ctc_topo_P_vec = k2.create_fsa_vec([ctc_topo_P.detach()])
        if replicate_den:
            indexes = torch.zeros(len(texts),
                                  dtype=torch.int32,
                                  device=self.device)
            den = k2.index_fsa(ctc_topo_P_vec, indexes)
        else:
            den = ctc_topo_P_vec

        return num, den
Example #12
def intersect_with_self_loops(base_graph: 'k2.Fsa', aux_graph: 'k2.Fsa') -> 'k2.Fsa':
    """Intersect `base_graph` with `aux_graph` after adding epsilon
    self-loops to `aux_graph`.
    """
    assert hasattr(base_graph, "aux_labels")
    assert not hasattr(aux_graph, "aux_labels")
    aux_graph_with_self_loops = k2.arc_sort(
        k2.add_epsilon_self_loops(aux_graph)).to(base_graph.device)
    result = k2.intersect(k2.arc_sort(base_graph),
                          aux_graph_with_self_loops,
                          treat_epsilons_specially=False)
    setattr(result, "phones", result.labels)
    return result
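A minimal usage sketch (not part of the original snippet): the toy inputs below are hypothetical stand-ins built with k2.linear_fst and k2.linear_fsa, which other examples on this page also use. `base` plays the role of a graph carrying aux_labels and `aux` that of a plain acceptor over the same label IDs.

import k2

# hypothetical toy inputs: a linear FST (has aux_labels) and a linear
# acceptor (no aux_labels) over the same label IDs
base = k2.linear_fst([1, 2, 2], [1, 0, 0])
aux = k2.linear_fsa([1, 2, 2])

out = intersect_with_self_loops(base, aux)
print(out.phones)  # holds the same values as out.labels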
Example #13
    def test_case3(self):
        devices = [torch.device('cpu')]
        if torch.cuda.is_available():
            devices.append(torch.device('cuda'))

        for device in devices:
            # (T, N, C)
            torch_activation = torch.tensor([[
                [-5, -4, -3, -2, -1],
                [-10, -9, -8, -7, -6],
                [-15, -14, -13, -12, -11.],
            ]]).permute(1, 0, 2).to(device).requires_grad_(True)
            torch_activation = torch_activation.to(torch.float32)
            torch_activation.requires_grad_(True)

            k2_activation = torch_activation.detach().clone().requires_grad_(
                True)

            torch_log_probs = torch.nn.functional.log_softmax(
                torch_activation, dim=-1)  # (T, N, C)
            # we have only one sequence and its labels are `b,c`
            targets = torch.tensor([2, 3]).to(device)
            input_lengths = torch.tensor([3]).to(device)
            target_lengths = torch.tensor([2]).to(device)

            torch_loss = torch.nn.functional.ctc_loss(
                log_probs=torch_log_probs,
                targets=targets,
                input_lengths=input_lengths,
                target_lengths=target_lengths,
                reduction='none')

            act = k2_activation.permute(1, 0, 2)  # (T, N, C) -> (N, T, C)
            k2_log_probs = torch.nn.functional.log_softmax(act, dim=-1)

            supervision_segments = torch.tensor([[0, 0, 3]], dtype=torch.int32)
            dense_fsa_vec = k2.DenseFsaVec(k2_log_probs,
                                           supervision_segments).to(device)

            ctc_topo_inv = k2.arc_sort(
                build_ctc_topo([0, 1, 2, 3, 4]).invert_())
            linear_fsa = k2.linear_fsa([2, 3])
            decoding_graph = k2.intersect(ctc_topo_inv, linear_fsa)
            decoding_graph = k2.connect(decoding_graph).invert_().to(device)

            target_graph = k2.intersect_dense(decoding_graph, dense_fsa_vec,
                                              100.0)

            k2_scores = target_graph.get_tot_scores(log_semiring=True,
                                                    use_double_scores=False)
            assert torch.allclose(torch_loss, -1 * k2_scores)
            assert torch.allclose(torch_loss,
                                  torch.tensor([4.938850402832]).to(device))

            torch_loss.backward()
            (-k2_scores).backward()
            assert torch.allclose(torch_activation.grad, k2_activation.grad)
Example #14
    def compile_one_and_cache(self, text: str) -> k2.Fsa:
        tokens = (token if token in self.words else self.oov
                  for token in text.split(' '))
        word_ids = [self.words[token] for token in tokens]
        fsa = k2.linear_fsa(word_ids)
        decoding_graph = k2.connect(k2.intersect(fsa, self.L_inv)).invert_()
        decoding_graph = k2.arc_sort(decoding_graph)
        decoding_graph = k2.compose(self.ctc_topo, decoding_graph)
        decoding_graph = k2.connect(decoding_graph)
        return decoding_graph
Example #15
    def test_random_case1(self):
        # 1 sequence
        devices = [torch.device('cpu')]
        if torch.cuda.is_available():
            devices.append(torch.device('cuda', 0))

        for device in devices:
            T = torch.randint(10, 100, (1,)).item()
            C = torch.randint(20, 30, (1,)).item()
            torch_activation = torch.rand((1, T + 10, C),
                                          dtype=torch.float32,
                                          device=device).requires_grad_(True)

            k2_activation = torch_activation.detach().clone().requires_grad_(
                True)

            # [N, T, C] -> [T, N, C]
            torch_log_probs = torch.nn.functional.log_softmax(
                torch_activation.permute(1, 0, 2), dim=-1)

            input_lengths = torch.tensor([T]).to(device)
            target_lengths = torch.randint(1, T, (1,)).to(device)
            targets = torch.randint(1, C - 1,
                                    (target_lengths.item(),)).to(device)

            torch_loss = torch.nn.functional.ctc_loss(
                log_probs=torch_log_probs,
                targets=targets,
                input_lengths=input_lengths,
                target_lengths=target_lengths,
                reduction='none')
            k2_log_probs = torch.nn.functional.log_softmax(k2_activation,
                                                           dim=-1)
            supervision_segments = torch.tensor([[0, 0, T]], dtype=torch.int32)
            dense_fsa_vec = k2.DenseFsaVec(k2_log_probs,
                                           supervision_segments).to(device)
            ctc_topo_inv = k2.arc_sort(
                build_ctc_topo(list(range(C))).invert_())
            linear_fsa = k2.linear_fsa([targets.tolist()])

            decoding_graph = k2.intersect(ctc_topo_inv, linear_fsa)
            decoding_graph = k2.connect(decoding_graph).invert_().to(device)

            target_graph = k2.intersect_dense(decoding_graph, dense_fsa_vec,
                                              100.0)

            k2_scores = target_graph.get_tot_scores(log_semiring=True,
                                                    use_double_scores=False)
            assert torch.allclose(torch_loss, -1 * k2_scores)
            scale = torch.rand_like(torch_loss) * 100
            (torch_loss * scale).sum().backward()
            (-k2_scores * scale).sum().backward()
            assert torch.allclose(torch_activation.grad,
                                  k2_activation.grad,
                                  atol=1e-2)
Example #16
def create_decoding_graph(texts, L, symbols):
    word_ids_list = []
    for text in texts:
        filter_text = [
            i if i in symbols._sym2id else '<UNK>' for i in text.split(' ')
        ]
        word_ids = [symbols.get(i) for i in filter_text]
        word_ids_list.append(word_ids)
    fsa = k2.linear_fsa(word_ids_list)
    decoding_graph = k2.intersect(fsa, L).invert_()
    decoding_graph = k2.add_epsilon_self_loops(decoding_graph)
    return decoding_graph
Example #17
    def compile(self, texts: Iterable[str],
                P: k2.Fsa) -> Tuple[k2.Fsa, k2.Fsa, k2.Fsa]:
        '''Create numerator and denominator graphs from transcripts
        and the bigram phone LM.

        Args:
          texts:
            A list of transcripts. Within a transcript, words are
            separated by spaces.
          P:
            The bigram phone LM created by :func:`create_bigram_phone_lm`.
        Returns:
          A tuple (num_graph, den_graph, decoding_graph), where

            - `num_graph` is the numerator graph. It is an FsaVec with
              shape `(len(texts), None, None)`.
              It is the result of compose(ctc_topo, P, L, transcript)

            - `den_graph` is the denominator graph. It is an FsaVec with the
              same shape as `num_graph`.
              It is the result of compose(ctc_topo, P).

            - decoding_graph: It is the result of compose(ctc_topo, L_disambig, G).
              Note that it is a single Fsa, not an FsaVec.
        '''
        assert P.device == self.device
        P_with_self_loops = k2.add_epsilon_self_loops(P)

        ctc_topo_P = k2.intersect(self.ctc_topo_inv,
                                  P_with_self_loops,
                                  treat_epsilons_specially=False).invert()
        ctc_topo_P = k2.arc_sort(ctc_topo_P)

        num_graphs = self.build_num_graphs(texts)

        num_graphs_with_self_loops = k2.remove_epsilon_and_add_self_loops(
            num_graphs)

        num_graphs_with_self_loops = k2.arc_sort(num_graphs_with_self_loops)

        num = k2.compose(ctc_topo_P,
                         num_graphs_with_self_loops,
                         treat_epsilons_specially=False,
                         inner_labels='phones')
        num = k2.arc_sort(num)

        ctc_topo_P_vec = k2.create_fsa_vec([ctc_topo_P.detach()])
        indexes = torch.zeros(len(texts),
                              dtype=torch.int32,
                              device=self.device)
        den = k2.index_fsa(ctc_topo_P_vec, indexes)

        return num, den, self.decoding_graph
Example #18
def create_decoding_graph(texts, graph, symbols):
    fsas = []
    for text in texts:
        filter_text = [
            i if i in symbols._sym2id else '<UNK>' for i in text.split(' ')
        ]
        word_ids = [symbols.get(i) for i in filter_text]
        fsa = k2.linear_fsa(word_ids)
        fsa = k2.arc_sort(fsa)
        decoding_graph = k2.intersect(fsa, graph).invert_()
        decoding_graph = k2.add_epsilon_self_loops(decoding_graph)
        fsas.append(decoding_graph)
    return k2.create_fsa_vec(fsas)
Example #19
    def __init__(self,
                 lexicon: Lexicon,
                 P: k2.Fsa,
                 device: torch.device,
                 oov: str = '<UNK>'):
        '''
        Args:
          lexicon:
            The lexicon. Its `L_inv` has words as labels and phones as
            aux_labels.
          P:
            A phone bigram LM if the pronunciations in the lexicon are in phones;
            a word piece bigram if the pronunciations in the lexicon are word pieces.
          device:
            The device on which the graphs are built.
          oov:
            Out of vocabulary word.
        '''
        self.lexicon = lexicon
        L_inv = self.lexicon.L_inv.to(device)
        P = P.to(device)

        if L_inv.properties & k2.fsa_properties.ARC_SORTED == 0:
            L_inv = k2.arc_sort(L_inv)

        assert L_inv.requires_grad is False

        assert oov in self.lexicon.words

        self.L_inv = L_inv
        self.oov_id = self.lexicon.words[oov]
        self.oov = oov
        self.device = device

        phone_symbols = get_phone_symbols(self.lexicon.phones)
        phone_symbols_with_blank = [0] + phone_symbols

        ctc_topo = build_ctc_topo(phone_symbols_with_blank).to(device)
        assert ctc_topo.requires_grad is False

        ctc_topo_inv = k2.arc_sort(ctc_topo.invert_())

        P_with_self_loops = k2.add_epsilon_self_loops(P)

        ctc_topo_P = k2.intersect(ctc_topo_inv,
                                  P_with_self_loops,
                                  treat_epsilons_specially=False).invert()
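        # ctc_topo_P is effectively ctc_topo composed with P: its labels are
        # CTC-topology symbols (blank included) and its aux_labels are phone IDs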

        self.ctc_topo_P = k2.arc_sort(ctc_topo_P)
Example #20
    def test_case2(self):
        for device in self.devices:
            # (T, N, C)
            torch_activation = torch.arange(1, 16).reshape(1, 3, 5).permute(
                1, 0, 2).to(device)
            torch_activation = torch_activation.to(torch.float32)
            torch_activation.requires_grad_(True)

            k2_activation = torch_activation.detach().clone().requires_grad_(
                True)

            torch_log_probs = torch.nn.functional.log_softmax(
                torch_activation, dim=-1)  # (T, N, C)
            # we have only one sequence and its labels are `c,c`
            targets = torch.tensor([3, 3]).to(device)
            input_lengths = torch.tensor([3]).to(device)
            target_lengths = torch.tensor([2]).to(device)

            torch_loss = torch.nn.functional.ctc_loss(
                log_probs=torch_log_probs,
                targets=targets,
                input_lengths=input_lengths,
                target_lengths=target_lengths,
                reduction='none')

            act = k2_activation.permute(1, 0, 2)  # (T, N, C) -> (N, T, C)
            k2_log_probs = torch.nn.functional.log_softmax(act, dim=-1)

            supervision_segments = torch.tensor([[0, 0, 3]], dtype=torch.int32)
            dense_fsa_vec = k2.DenseFsaVec(k2_log_probs,
                                           supervision_segments).to(device)

            ctc_topo_inv = k2.arc_sort(
                build_ctc_topo([0, 1, 2, 3, 4]).invert_())
            linear_fsa = k2.linear_fsa([3, 3])
            decoding_graph = k2.intersect(ctc_topo_inv, linear_fsa)
            decoding_graph = k2.connect(decoding_graph).invert_().to(device)

            target_graph = k2.intersect_dense(decoding_graph, dense_fsa_vec,
                                              100.0)

            k2_scores = target_graph.get_tot_scores(log_semiring=True,
                                                    use_double_scores=False)
            assert torch.allclose(torch_loss, -1 * k2_scores)
            assert torch.allclose(torch_loss,
                                  torch.tensor([7.355742931366]).to(device))

            torch_loss.backward()
            (-k2_scores).backward()
            assert torch.allclose(torch_activation.grad, k2_activation.grad)
Example #21
def compile_LG(L: Fsa, G: Fsa, labels_disambig_id_start: int,
               aux_labels_disambig_id_start: int) -> Fsa:
    """
    Creates a decoding graph using a lexicon fst ``L`` and language model fsa ``G``.
    Involves arc sorting, intersection, determinization, removal of disambiguation symbols
    and adding epsilon self-loops.

    Args:
        L:
            An ``Fsa`` that represents the lexicon (L), i.e. has phones as ``symbols``
            and words as ``aux_symbols``.
        G:
            An ``Fsa`` that represents the language model (G), i.e. it's an acceptor
            with words as ``symbols``.
        labels_disambig_id_start:
            An integer ID corresponding to the first disambiguation symbol in the
            phonetic alphabet.
        aux_labels_disambig_id_start:
            An integer ID corresponding to the first disambiguation symbol in the
            words vocabulary.
    Returns:
        An arc-sorted ``Fsa`` representing the decoding graph ``LG``, with
        disambiguation symbols removed and epsilon self-loops added.
    """
    L_inv = k2.arc_sort(L.invert_())
    G = k2.arc_sort(G)
    logging.debug("Intersecting L and G")
    LG = k2.intersect(L_inv, G)
    logging.debug(f'LG shape = {LG.shape}')
    logging.debug("Connecting L*G")
    LG = k2.connect(LG).invert_()
    logging.debug(f'LG shape = {LG.shape}')
    logging.debug("Determinizing L*G")
    LG = k2.determinize(LG)
    logging.debug(f'LG shape = {LG.shape}')
    logging.debug("Connecting det(L*G)")
    LG = k2.connect(LG)
    logging.debug(f'LG shape = {LG.shape}')
    logging.debug("Removing disambiguation symbols on L*G")
    LG.labels[LG.labels >= labels_disambig_id_start] = 0
    LG.aux_labels[LG.aux_labels >= aux_labels_disambig_id_start] = 0
    LG = k2.add_epsilon_self_loops(LG)
    LG = k2.arc_sort(LG)
    logging.debug(
        f'LG is arc sorted: {(LG.properties & k2.fsa_properties.ARC_SORTED) != 0}'
    )
    return LG
Example #22
    def compile_one_and_cache(self, text: str) -> k2.Fsa:
        '''Convert transcript to an Fsa with the help of lexicon
        and word symbol table.

        Args:
          text:
            The transcript containing words separated by spaces.

        Returns:
          Return an FST corresponding to the transcript. Its `labels` are
          phone IDs and `aux_labels` are word IDs.
        '''
        tokens = (token if token in self.words else self.oov
                  for token in text.split(' '))
        word_ids = [self.words[token] for token in tokens]
        fsa = k2.linear_fsa(word_ids)
        num_graph = k2.connect(k2.intersect(fsa, self.L_inv)).invert_()
        num_graph = k2.arc_sort(num_graph)
        return num_graph
Example #23
    def test_composition_equivalence(self):
        index = _generate_fsa_vec()
        index = k2.arc_sort(k2.connect(k2.remove_epsilon(index)))

        src = _generate_fsa_vec()

        replace = k2.replace_fsa(src, index, 1)
        replace = k2.top_sort(replace)

        f_fsa = _construct_f(src)
        f_fsa = k2.arc_sort(f_fsa)
        intersect = k2.intersect(index, f_fsa, treat_epsilons_specially=True)
        intersect = k2.invert(intersect)
        intersect = k2.top_sort(intersect)
        delattr(intersect, 'aux_labels')

        assert k2.is_rand_equivalent(replace,
                                     intersect,
                                     log_semiring=True,
                                     delta=1e-3)
Example #24
def create_decoding_graph(texts, L, symbols):
    fsas = []
    for text in texts:
        filter_text = [
            i if i in symbols._sym2id else '<UNK>' for i in text.split(' ')
        ]
        word_ids = [symbols.get(i) for i in filter_text]
        fsa = k2.linear_fsa(word_ids)
        print("linear fsa is ", fsa)
        fsa = k2.arc_sort(fsa)
        print("linear fsa, arc-sorted, is ", fsa)
        print("begin")
        print(k2.is_arc_sorted(k2.get_properties(fsa)))
        decoding_graph = k2.intersect(fsa, L).invert_()
        print("linear fsa, composed, is ", fsa)
        print("decoding graph is ", decoding_graph)
        decoding_graph = k2.add_epsilon_self_loops(decoding_graph)
        print("decoding graph with self-loops is ", decoding_graph)
        fsas.append(decoding_graph)
    return k2.create_fsa_vec(fsas)
Example #25
def get_hierarchical_targets(ys: List[List[int]],
                             lexicon: k2.Fsa) -> List[Tensor]:
    """Get hierarchical transcripts (i.e., phone level transcripts) from transcripts (i.e., word level transcripts).

    Args:
        ys: Word level transcripts.
        lexicon: Its labels are words, while its aux_labels are phones.

    Returns:
        List[Tensor]: Phone level transcripts.

    """

    if lexicon is None:
        return ys
    else:
        L_inv = lexicon

    n_batch = len(ys)
    indices = torch.tensor(range(n_batch))
    device = L_inv.device

    transcripts = k2.create_fsa_vec(
        [k2.linear_fsa(x, device=device) for x in ys])
    transcripts_with_self_loops = k2.add_epsilon_self_loops(transcripts)

    transcripts_lexicon = k2.intersect(L_inv,
                                       transcripts_with_self_loops,
                                       treat_epsilons_specially=False)
    # Don't call invert_() above because we want to return phone IDs,
    # which are the `aux_labels` of transcripts_lexicon
    transcripts_lexicon = k2.remove_epsilon(transcripts_lexicon)
    transcripts_lexicon = k2.top_sort(transcripts_lexicon)

    transcripts_lexicon = k2.shortest_path(transcripts_lexicon,
                                           use_double_scores=True)

    ys = get_texts(transcripts_lexicon, indices)
    ys = [torch.tensor(y) for y in ys]

    return ys
Example #26
def compile_LG(L: Fsa, G: Fsa, ctc_topo_inv: Fsa,
               labels_disambig_id_start: int,
               aux_labels_disambig_id_start: int) -> Fsa:
    """
    Creates a decoding graph using a lexicon fst ``L`` and language model fsa ``G``.
    Involves arc sorting, intersection, determinization, removal of disambiguation symbols
    and adding epsilon self-loops.

    Args:
        L:
            An ``Fsa`` that represents the lexicon (L), i.e. has phones as ``symbols``
            and words as ``aux_symbols``.
        G:
            An ``Fsa`` that represents the language model (G), i.e. it's an acceptor
            with words as ``symbols``.
        ctc_topo_inv:  Epsilons are in `aux_labels` and `labels` contain phone IDs.
        labels_disambig_id_start:
            An integer ID corresponding to the first disambiguation symbol in the
            phonetic alphabet.
        aux_labels_disambig_id_start:
            An integer ID corresponding to the first disambiguation symbol in the
            words vocabulary.
    Returns:
        An arc-sorted ``Fsa`` representing the decoding graph: ``ctc_topo_inv``
        composed with the determinized, epsilon-free ``LG``.
    """
    L_inv = k2.arc_sort(L.invert_())
    G = k2.arc_sort(G)
    logging.debug("Intersecting L and G")
    LG = k2.intersect(L_inv, G)
    logging.debug(f'LG shape = {LG.shape}')
    logging.debug("Connecting L*G")
    LG = k2.connect(LG).invert_()
    logging.debug(f'LG shape = {LG.shape}')
    logging.debug("Determinizing L*G")
    LG = k2.determinize(LG)
    logging.debug(f'LG shape = {LG.shape}')
    logging.debug("Connecting det(L*G)")
    LG = k2.connect(LG)
    logging.debug(f'LG shape = {LG.shape}')
    logging.debug("Removing disambiguation symbols on L*G")
    LG.labels[LG.labels >= labels_disambig_id_start] = 0
    if isinstance(LG.aux_labels, torch.Tensor):
        LG.aux_labels[LG.aux_labels >= aux_labels_disambig_id_start] = 0
    else:
        LG.aux_labels.values()[
            LG.aux_labels.values() >= aux_labels_disambig_id_start] = 0
    logging.debug("Removing epsilons")
    LG = k2.remove_epsilons_iterative_tropical(LG)
    logging.debug(f'LG shape = {LG.shape}')
    logging.debug("Connecting rm-eps(det(L*G))")
    LG = k2.connect(LG)
    logging.debug(f'LG shape = {LG.shape}')
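    # aux_labels may be ragged after epsilon removal; drop the epsilon (0)
    # values from them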
    LG.aux_labels = k2.ragged.remove_values_eq(LG.aux_labels, 0)

    logging.debug("Arc sorting")
    LG = k2.arc_sort(LG)

    logging.debug("Composing")
    LG = k2.compose(ctc_topo_inv, LG)

    logging.debug("Connecting")
    LG = k2.connect(LG)

    logging.debug("Arc sorting")
    LG = k2.arc_sort(LG)
    logging.debug(
        f'LG is arc sorted: {(LG.properties & k2.fsa_properties.ARC_SORTED) != 0}'
    )
    return LG
Example #27
    def visualize_ctc_topo():
        '''This function shows how to visualize
        standard/modified ctc topologies. It's for
        demonstration only, not for testing.
        '''
        max_token = 2
        labels_sym = k2.SymbolTable.from_str('''
            <blk> 0
            z 1
            o 2
        ''')
        aux_labels_sym = k2.SymbolTable.from_str('''
            z 1
            o 2
        ''')

        word_sym = k2.SymbolTable.from_str('''
            zoo 1
        ''')

        standard = k2.ctc_topo(max_token, modified=False)
        modified = k2.ctc_topo(max_token, modified=True)
        standard.labels_sym = labels_sym
        standard.aux_labels_sym = aux_labels_sym

        modified.labels_sym = labels_sym
        modified.aux_labels_sym = aux_labels_sym

        standard.draw('standard_topo.svg', title='standard CTC topo')
        modified.draw('modified_topo.svg', title='modified CTC topo')
        fsa = k2.linear_fst([1, 2, 2], [1, 0, 0])
        fsa.labels_sym = labels_sym
        fsa.aux_labels_sym = word_sym
        fsa.draw('transcript.svg', title='transcript')

        standard_graph = k2.compose(standard, fsa)
        modified_graph = k2.compose(modified, fsa)
        standard_graph.draw('standard_graph.svg', title='standard graph')
        modified_graph.draw('modified_graph.svg', title='modified graph')

        # z z <blk> <blk> o o <blk> o <blk>
        inputs = k2.linear_fsa([1, 1, 0, 0, 2, 2, 0, 2, 0])
        inputs.labels_sym = labels_sym
        inputs.draw('inputs.svg', title='inputs')
        standard_lattice = k2.intersect(standard_graph,
                                        inputs,
                                        treat_epsilons_specially=False)
        standard_lattice.draw('standard_lattice.svg', title='standard lattice')

        modified_lattice = k2.intersect(modified_graph,
                                        inputs,
                                        treat_epsilons_specially=False)
        modified_lattice = k2.connect(modified_lattice)
        modified_lattice.draw('modified_lattice.svg', title='modified lattice')

        # z z <blk> <blk> o o o <blk>
        inputs2 = k2.linear_fsa([1, 1, 0, 0, 2, 2, 2, 0])
        inputs2.labels_sym = labels_sym
        inputs2.draw('inputs2.svg', title='inputs2')
        standard_lattice2 = k2.intersect(standard_graph,
                                         inputs2,
                                         treat_epsilons_specially=False)
        standard_lattice2 = k2.connect(standard_lattice2)
        # It's empty since the standard topo requires a blank
        # between the two o's in "zoo"
        assert standard_lattice2.num_arcs == 0
        standard_lattice2.draw('standard_lattice2.svg',
                               title='standard lattice2')

        modified_lattice2 = k2.intersect(modified_graph,
                                         inputs2,
                                         treat_epsilons_specially=False)
        modified_lattice2 = k2.connect(modified_lattice2)
        modified_lattice2.draw('modified_lattice2.svg',
                               title='modified lattice2')
Example #28
    def test_case4(self):
        for device in self.devices:
            # put case3, case2 and case1 into a batch
            torch_activation_1 = torch.tensor(
                [[0., 0., 0., 0., 0.]]).to(device).requires_grad_(True)

            torch_activation_2 = torch.arange(1, 16).reshape(3, 5).to(
                torch.float32).to(device).requires_grad_(True)

            torch_activation_3 = torch.tensor([
                [-5, -4, -3, -2, -1],
                [-10, -9, -8, -7, -6],
                [-15, -14, -13, -12, -11.],
            ]).to(device).requires_grad_(True)

            k2_activation_1 = torch_activation_1.detach().clone(
            ).requires_grad_(True)
            k2_activation_2 = torch_activation_2.detach().clone(
            ).requires_grad_(True)
            k2_activation_3 = torch_activation_3.detach().clone(
            ).requires_grad_(True)

            # [T, N, C]
            torch_activations = torch.nn.utils.rnn.pad_sequence(
                [torch_activation_3, torch_activation_2, torch_activation_1],
                batch_first=False,
                padding_value=0)

            # [N, T, C]
            k2_activations = torch.nn.utils.rnn.pad_sequence(
                [k2_activation_3, k2_activation_2, k2_activation_1],
                batch_first=True,
                padding_value=0)

            # [[b,c], [c,c], [a]]
            targets = torch.tensor([2, 3, 3, 3, 1]).to(device)
            input_lengths = torch.tensor([3, 3, 1]).to(device)
            target_lengths = torch.tensor([2, 2, 1]).to(device)

            torch_log_probs = torch.nn.functional.log_softmax(
                torch_activations, dim=-1)  # (T, N, C)

            torch_loss = torch.nn.functional.ctc_loss(
                log_probs=torch_log_probs,
                targets=targets,
                input_lengths=input_lengths,
                target_lengths=target_lengths,
                reduction='none')

            assert torch.allclose(
                torch_loss,
                torch.tensor([4.938850402832, 7.355742931366,
                              1.6094379425049]).to(device))

            k2_log_probs = torch.nn.functional.log_softmax(k2_activations,
                                                           dim=-1)
            supervision_segments = torch.tensor(
                [[0, 0, 3], [1, 0, 3], [2, 0, 1]], dtype=torch.int32)
            dense_fsa_vec = k2.DenseFsaVec(k2_log_probs,
                                           supervision_segments).to(device)

            ctc_topo_inv = k2.arc_sort(
                build_ctc_topo([0, 1, 2, 3, 4]).invert_())
            # [ [b, c], [c, c], [a]]
            linear_fsa = k2.linear_fsa([[2, 3], [3, 3], [1]])
            decoding_graph = k2.intersect(ctc_topo_inv, linear_fsa)
            decoding_graph = k2.connect(decoding_graph).invert_().to(device)

            target_graph = k2.intersect_dense(decoding_graph, dense_fsa_vec,
                                              100.0)

            k2_scores = target_graph.get_tot_scores(log_semiring=True,
                                                    use_double_scores=False)
            assert torch.allclose(torch_loss, -1 * k2_scores)

            scale = torch.tensor([1., -2, 3.5]).to(device)
            (torch_loss * scale).sum().backward()
            (-k2_scores * scale).sum().backward()
            assert torch.allclose(torch_activation_1.grad,
                                  k2_activation_1.grad)
            assert torch.allclose(torch_activation_2.grad,
                                  k2_activation_2.grad)
            assert torch.allclose(torch_activation_3.grad,
                                  k2_activation_3.grad)
Example #29
def main():
    # load L, G, symbol_table
    lang_dir = 'data/lang_nosp'
    with open(lang_dir + '/L.fst.txt') as f:
        L = k2.Fsa.from_openfst(f.read(), acceptor=False)

    with open(lang_dir + '/G.fsa.txt') as f:
        G = k2.Fsa.from_openfst(f.read(), acceptor=True)

    with open(lang_dir + '/words.txt') as f:
        symbol_table = k2.SymbolTable.from_str(f.read())

    L = k2.arc_sort(L.invert_())
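    # after invert_(), L has words as labels and phones as aux_labels,
    # so it can be intersected with the word-level acceptor G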
    G = k2.arc_sort(G)
    graph = k2.intersect(L, G)
    graph = k2.arc_sort(graph)

    # load dataset
    feature_dir = 'exp/data1'
    cuts_train = CutSet.from_json(feature_dir +
                                  '/cuts_train-clean-100.json.gz')

    cuts_dev = CutSet.from_json(feature_dir + '/cuts_dev-clean.json.gz')

    train = K2SpeechRecognitionIterableDataset(cuts_train, shuffle=True)
    validate = K2SpeechRecognitionIterableDataset(cuts_dev, shuffle=False)
    train_dl = torch.utils.data.DataLoader(train,
                                           batch_size=None,
                                           num_workers=1)
    valid_dl = torch.utils.data.DataLoader(validate,
                                           batch_size=None,
                                           num_workers=1)

    dir = 'exp'
    setup_logger('{}/log/log-train'.format(dir))

    if not torch.cuda.is_available():
        logging.error('No GPU detected!')
        sys.exit(-1)

    device_id = 0
    device = torch.device('cuda', device_id)
    model = Wav2Letter(num_classes=364, input_type='mfcc', num_features=40)
    model.to(device)

    learning_rate = 0.001
    start_epoch = 0
    num_epochs = 10
    best_objf = 100000
    best_epoch = start_epoch
    best_model_path = os.path.join(dir, 'best_model.pt')
    best_epoch_info_filename = os.path.join(dir, 'best-epoch-info')

    optimizer = optim.Adam(model.parameters(),
                           lr=learning_rate,
                           weight_decay=5e-4)
    # optimizer = optim.SGD(net.parameters(), lr=learning_rate, momentum=0.9)

    for epoch in range(start_epoch, num_epochs):
        curr_learning_rate = learning_rate * pow(0.4, epoch)
        for param_group in optimizer.param_groups:
            param_group['lr'] = curr_learning_rate

        logging.info('epoch {}, learning rate {}'.format(
            epoch, curr_learning_rate))
        objf = train_one_epoch(dataloader=train_dl,
                               valid_dataloader=valid_dl,
                               model=model,
                               device=device,
                               graph=graph,
                               symbols=symbol_table,
                               optimizer=optimizer,
                               current_epoch=epoch,
                               num_epochs=num_epochs)
        if objf < best_objf:
            best_objf = objf
            best_epoch = epoch
            save_checkpoint(filename=best_model_path,
                            model=model,
                            epoch=epoch,
                            learning_rate=curr_learning_rate,
                            objf=objf)
            save_training_info(filename=best_epoch_info_filename,
                               model_path=best_model_path,
                               current_epoch=epoch,
                               learning_rate=curr_learning_rate,
                               objf=best_objf,
                               best_objf=best_objf,
                               best_epoch=best_epoch)

        # we always save the model for every epoch
        model_path = os.path.join(dir, 'epoch-{}.pt'.format(epoch))
        save_checkpoint(filename=model_path,
                        model=model,
                        epoch=epoch,
                        learning_rate=curr_learning_rate,
                        objf=objf)
        epoch_info_filename = os.path.join(dir, 'epoch-{}-info'.format(epoch))
        save_training_info(filename=epoch_info_filename,
                           model_path=model_path,
                           current_epoch=epoch,
                           learning_rate=curr_learning_rate,
                           objf=objf,
                           best_objf=best_objf,
                           best_epoch=best_epoch)

    logging.warning('Done')
Example #30
    def test_random_case2(self):
        # 2 sequences
        for device in self.devices:
            T1 = torch.randint(10, 200, (1, )).item()
            T2 = torch.randint(9, 100, (1, )).item()
            C = torch.randint(20, 30, (1, )).item()
            if T1 < T2:
                T1, T2 = T2, T1

            torch_activation_1 = torch.rand((T1, C),
                                            dtype=torch.float32,
                                            device=device).requires_grad_(True)
            torch_activation_2 = torch.rand((T2, C),
                                            dtype=torch.float32,
                                            device=device).requires_grad_(True)

            k2_activation_1 = torch_activation_1.detach().clone(
            ).requires_grad_(True)
            k2_activation_2 = torch_activation_2.detach().clone(
            ).requires_grad_(True)

            # [T, N, C]
            torch_activations = torch.nn.utils.rnn.pad_sequence(
                [torch_activation_1, torch_activation_2],
                batch_first=False,
                padding_value=0)

            # [N, T, C]
            k2_activations = torch.nn.utils.rnn.pad_sequence(
                [k2_activation_1, k2_activation_2],
                batch_first=True,
                padding_value=0)

            target_length1 = torch.randint(1, T1, (1, )).item()
            target_length2 = torch.randint(1, T2, (1, )).item()

            target_lengths = torch.tensor([target_length1,
                                           target_length2]).to(device)
            targets = torch.randint(1, C - 1,
                                    (target_lengths.sum(), )).to(device)

            # [T, N, C]
            torch_log_probs = torch.nn.functional.log_softmax(
                torch_activations, dim=-1)
            input_lengths = torch.tensor([T1, T2]).to(device)

            torch_loss = torch.nn.functional.ctc_loss(
                log_probs=torch_log_probs,
                targets=targets,
                input_lengths=input_lengths,
                target_lengths=target_lengths,
                reduction='none')

            assert T1 >= T2
            supervision_segments = torch.tensor([[0, 0, T1], [1, 0, T2]],
                                                dtype=torch.int32)
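            # segments are listed in decreasing order of duration;
            # the assert above (T1 >= T2) guarantees this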
            k2_log_probs = torch.nn.functional.log_softmax(k2_activations,
                                                           dim=-1)
            dense_fsa_vec = k2.DenseFsaVec(k2_log_probs,
                                           supervision_segments).to(device)
            ctc_topo_inv = k2.arc_sort(
                build_ctc_topo(list(range(C))).invert_())
            linear_fsa = k2.linear_fsa([
                targets[:target_length1].tolist(),
                targets[target_length1:].tolist()
            ])
            decoding_graph = k2.intersect(ctc_topo_inv, linear_fsa)
            decoding_graph = k2.connect(decoding_graph).invert_().to(device)
            target_graph = k2.intersect_dense(decoding_graph, dense_fsa_vec,
                                              100.0)
            k2_scores = target_graph.get_tot_scores(log_semiring=True,
                                                    use_double_scores=False)
            assert torch.allclose(torch_loss, -1 * k2_scores)
            scale = torch.rand_like(torch_loss) * 100
            (torch_loss * scale).sum().backward()
            (-k2_scores * scale).sum().backward()
            assert torch.allclose(torch_activation_1.grad,
                                  k2_activation_1.grad,
                                  atol=1e-2)
            assert torch.allclose(torch_activation_2.grad,
                                  k2_activation_2.grad,
                                  atol=1e-2)