Example No. 1
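Several of these examples are methods excerpted from a larger unittest class; they assume torch and k2 are already imported and, in the later examples, a self.devices list built during setup. A minimal sketch of that assumed scaffolding (the class name and setup shown here are hypothetical, not part of the original tests):

import unittest

import torch
import k2


class CtcTopoTest(unittest.TestCase):  # hypothetical name, for illustration only
    @classmethod
    def setUpClass(cls):
        # Run each test on CPU, and also on GPU when one is available.
        cls.devices = [torch.device('cpu')]
        if torch.cuda.is_available():
            cls.devices.append(torch.device('cuda', 0))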
    def test_no_repeated(self):
        # standard ctc topo and modified ctc topo
        # should be equivalent if there are no
        # repeated neighboring symbols in the transcript
        max_token = 3
        standard = k2.ctc_topo(max_token, modified=False)
        modified = k2.ctc_topo(max_token, modified=True)
        transcript = k2.linear_fsa([1, 2, 3])
        standard_graph = k2.compose(standard, transcript)
        modified_graph = k2.compose(modified, transcript)

        input1 = k2.linear_fsa([1, 1, 1, 0, 0, 2, 2, 3, 3])
        input2 = k2.linear_fsa([1, 1, 0, 0, 2, 2, 0, 3, 3])
        inputs = [input1, input2]
        for i in inputs:
            lattice1 = k2.intersect(standard_graph,
                                    i,
                                    treat_epsilons_specially=False)
            lattice2 = k2.intersect(modified_graph,
                                    i,
                                    treat_epsilons_specially=False)
            lattice1 = k2.connect(lattice1)
            lattice2 = k2.connect(lattice2)

            aux_labels1 = lattice1.aux_labels[lattice1.aux_labels != 0]
            aux_labels2 = lattice2.aux_labels[lattice2.aux_labels != 0]
            aux_labels1 = aux_labels1[:-1]  # remove -1
            aux_labels2 = aux_labels2[:-1]
            assert torch.all(torch.eq(aux_labels1, aux_labels2))
            assert torch.all(torch.eq(aux_labels2, torch.tensor([1, 2, 3])))
Example No. 2
    def test_with_repeated(self):
        max_token = 2
        standard = k2.ctc_topo(max_token, modified=False)
        modified = k2.ctc_topo(max_token, modified=True)
        transcript = k2.linear_fsa([1, 2, 2])
        standard_graph = k2.compose(standard, transcript)
        modified_graph = k2.compose(modified, transcript)

        # There is a blank separating the two 2s in the input,
        # so the standard and modified CTC topos should be equivalent.
        input = k2.linear_fsa([1, 1, 2, 2, 0, 2, 2, 0, 0])
        lattice1 = k2.intersect(standard_graph,
                                input,
                                treat_epsilons_specially=False)
        lattice2 = k2.intersect(modified_graph,
                                input,
                                treat_epsilons_specially=False)
        lattice1 = k2.connect(lattice1)
        lattice2 = k2.connect(lattice2)

        aux_labels1 = lattice1.aux_labels[lattice1.aux_labels != 0]
        aux_labels2 = lattice2.aux_labels[lattice2.aux_labels != 0]
        aux_labels1 = aux_labels1[:-1]  # remove -1
        aux_labels2 = aux_labels2[:-1]
        assert torch.all(torch.eq(aux_labels1, aux_labels2))
        assert torch.all(torch.eq(aux_labels1, torch.tensor([1, 2, 2])))

        # There is no blank separating the two 2s in the input.
        # The standard CTC topo requires a blank between repeated
        # symbols, so lattice1 below is empty.
        input = k2.linear_fsa([1, 1, 2, 2, 0, 0])
        lattice1 = k2.intersect(standard_graph,
                                input,
                                treat_epsilons_specially=False)
        lattice2 = k2.intersect(modified_graph,
                                input,
                                treat_epsilons_specially=False)
        lattice1 = k2.connect(lattice1)
        lattice2 = k2.connect(lattice2)
        assert lattice1.num_arcs == 0

        # Since there are two 2s in the input and there are also two 2s
        # in the transcript, the final output contains only one path.
        # If there were more than two 2s in the input, the output
        # would contain more than one path
        aux_labels2 = lattice2.aux_labels[lattice2.aux_labels != 0]
        aux_labels2 = aux_labels2[:-1]
        assert torch.all(torch.eq(aux_labels2, torch.tensor([1, 2, 2])))
Example No. 3
    def test_case1(self):
        for device in self.devices:
            # suppose we have five symbols: <blk>, a, b, c, d
            torch_activation = torch.tensor([0.2, 0.2, 0.2, 0.2,
                                             0.2]).to(device)
            k2_activation = torch_activation.detach().clone()

            # (T, N, C)
            torch_activation = torch_activation.reshape(
                1, 1, -1).requires_grad_(True)

            # (N, T, C)
            k2_activation = k2_activation.reshape(1, 1,
                                                  -1).requires_grad_(True)

            torch_log_probs = torch.nn.functional.log_softmax(
                torch_activation, dim=-1)  # (T, N, C)

            # we have only one sequence and its label is `a`
            targets = torch.tensor([1]).to(device)
            input_lengths = torch.tensor([1]).to(device)
            target_lengths = torch.tensor([1]).to(device)
            torch_loss = torch.nn.functional.ctc_loss(
                log_probs=torch_log_probs,
                targets=targets,
                input_lengths=input_lengths,
                target_lengths=target_lengths,
                reduction='mean')

            assert torch.allclose(torch_loss,
                                  torch.tensor([1.6094379425049]).to(device))

            # (N, T, C)
            k2_log_probs = torch.nn.functional.log_softmax(k2_activation,
                                                           dim=-1)

            supervision_segments = torch.tensor([[0, 0, 1]], dtype=torch.int32)
            dense_fsa_vec = k2.DenseFsaVec(k2_log_probs,
                                           supervision_segments).to(device)

            ctc_topo = k2.ctc_topo(4)
            linear_fsa = k2.linear_fsa([1])
            decoding_graph = k2.compose(ctc_topo, linear_fsa).to(device)

            k2_loss = k2.ctc_loss(decoding_graph,
                                  dense_fsa_vec,
                                  reduction='mean',
                                  target_lengths=target_lengths)

            assert torch.allclose(torch_loss, k2_loss)

            torch_loss.backward()
            k2_loss.backward()
            assert torch.allclose(torch_activation.grad, k2_activation.grad)
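The expected value 1.6094379425049 can be checked by hand: with a single frame and uniform activations over the five symbols, the only CTC alignment for the target `a` is that one frame emitting `a`, so the loss is -log(1/5) = log(5). A quick sanity check (not part of the original test):

import math

# -log p(a) for a uniform distribution over 5 symbols
print(math.log(5))  # 1.6094379124341003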
Example No. 4
    def test_case3(self):
        for device in self.devices:
            # (T, N, C)
            torch_activation = torch.tensor([[
                [-5, -4, -3, -2, -1],
                [-10, -9, -8, -7, -6],
                [-15, -14, -13, -12, -11.],
            ]]).permute(1, 0, 2).to(device).requires_grad_(True)

            k2_activation = torch_activation.detach().clone().requires_grad_(
                True)

            torch_log_probs = torch.nn.functional.log_softmax(
                torch_activation, dim=-1)  # (T, N, C)
            # we have only one sequence and its labels are `b,c`
            targets = torch.tensor([2, 3]).to(device)
            input_lengths = torch.tensor([3]).to(device)
            target_lengths = torch.tensor([2]).to(device)

            torch_loss = torch.nn.functional.ctc_loss(
                log_probs=torch_log_probs,
                targets=targets,
                input_lengths=input_lengths,
                target_lengths=target_lengths,
                reduction='mean')

            act = k2_activation.permute(1, 0, 2)  # (T, N, C) -> (N, T, C)
            k2_log_probs = torch.nn.functional.log_softmax(act, dim=-1)

            supervision_segments = torch.tensor([[0, 0, 3]], dtype=torch.int32)
            dense_fsa_vec = k2.DenseFsaVec(k2_log_probs,
                                           supervision_segments).to(device)

            ctc_topo = k2.ctc_topo(4)
            linear_fsa = k2.linear_fsa([2, 3])
            decoding_graph = k2.compose(ctc_topo, linear_fsa).to(device)

            k2_loss = k2.ctc_loss(decoding_graph,
                                  dense_fsa_vec,
                                  reduction='mean',
                                  target_lengths=target_lengths)

            expected_loss = torch.tensor([4.938850402832],
                                         device=device) / target_lengths

            assert torch.allclose(torch_loss, k2_loss)
            assert torch.allclose(torch_loss, expected_loss)

            torch_loss.backward()
            k2_loss.backward()
            assert torch.allclose(torch_activation.grad, k2_activation.grad)
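The constant 4.938850402832 above is the raw negative log-likelihood, before the 'mean' reduction divides it by the target length. With 3 frames and the target `b c`, only five alignments collapse to `b c` under the standard CTC rules, so the value can be reproduced by brute force (an illustrative sketch, not part of the original test):

import torch

act = torch.tensor([[-5., -4., -3., -2., -1.],
                    [-10., -9., -8., -7., -6.],
                    [-15., -14., -13., -12., -11.]])
log_probs = act.log_softmax(dim=-1)  # (T=3, C=5)

# All length-3 alignments over {<blk>=0, a=1, b=2, c=3, d=4}
# that collapse to [b, c] = [2, 3]:
alignments = [[2, 3, 0], [2, 0, 3], [0, 2, 3], [2, 2, 3], [2, 3, 3]]
scores = torch.tensor(
    [sum(log_probs[t, s].item() for t, s in enumerate(a)) for a in alignments])
print(-torch.logsumexp(scores, dim=0))  # ~4.9389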
Example No. 5
    def test(self):
        for device in self.devices:
            fsa1 = k2.ctc_topo(5, device=device)
            fsa1.attr1 = torch.tensor([1] * fsa1.num_arcs, device=device)

            stream1 = k2.RnntDecodingStream(fsa1)

            fsa2 = k2.trivial_graph(3, device=device)
            fsa2.attr1 = torch.tensor([2] * fsa2.num_arcs, device=device)
            fsa2.attr2 = torch.tensor([22] * fsa2.num_arcs, device=device)

            stream2 = k2.RnntDecodingStream(fsa2)

            fsa3 = k2.ctc_topo(3, modified=True, device=device)
            fsa3.attr3 = k2.RaggedTensor(
                torch.ones((fsa3.num_arcs, 2), dtype=torch.int32, device=device)
                * 3
            )

            stream3 = k2.RnntDecodingStream(fsa3)

            config = k2.RnntDecodingConfig(10, 2, 3.0, 3, 3)
            streams = k2.RnntDecodingStreams(
                [stream1, stream2, stream3], config
            )

            for i in range(5):
                shape, context = streams.get_contexts()
                logprobs = torch.randn(
                    (context.shape[0], 10), dtype=torch.float32, device=device
                )
                streams.advance(logprobs)

            streams.terminate_and_flush_to_streams()
            ofsa = streams.format_output([3, 4, 5])
            print(ofsa)
Example No. 6
    def test_random_case1(self):
        # 1 sequence
        for device in self.devices:
            T = torch.randint(10, 100, (1, )).item()
            C = torch.randint(20, 30, (1, )).item()
            torch_activation = torch.rand((1, T + 10, C),
                                          dtype=torch.float32,
                                          device=device).requires_grad_(True)

            k2_activation = torch_activation.detach().clone().requires_grad_(
                True)

            # [N, T, C] -> [T, N, C]
            torch_log_probs = torch.nn.functional.log_softmax(
                torch_activation.permute(1, 0, 2), dim=-1)

            input_lengths = torch.tensor([T]).to(device)
            target_lengths = torch.randint(1, T, (1, )).to(device)
            targets = torch.randint(1, C - 1,
                                    (target_lengths.item(), )).to(device)

            torch_loss = torch.nn.functional.ctc_loss(
                log_probs=torch_log_probs,
                targets=targets,
                input_lengths=input_lengths,
                target_lengths=target_lengths,
                reduction='mean')
            k2_log_probs = torch.nn.functional.log_softmax(k2_activation,
                                                           dim=-1)
            supervision_segments = torch.tensor([[0, 0, T]], dtype=torch.int32)
            dense_fsa_vec = k2.DenseFsaVec(k2_log_probs,
                                           supervision_segments).to(device)
            ctc_topo = k2.ctc_topo(C - 1)
            linear_fsa = k2.linear_fsa([targets.tolist()])
            decoding_graph = k2.compose(ctc_topo, linear_fsa).to(device)

            k2_loss = k2.ctc_loss(decoding_graph,
                                  dense_fsa_vec,
                                  reduction='mean',
                                  target_lengths=target_lengths)

            assert torch.allclose(torch_loss, k2_loss)
            scale = torch.rand_like(torch_loss) * 100
            (torch_loss * scale).sum().backward()
            (k2_loss * scale).sum().backward()
            assert torch.allclose(torch_activation.grad,
                                  k2_activation.grad,
                                  atol=1e-2)
Example No. 7
def _visualize_ctc_topo():
    '''See https://git.io/JtqyJ
    for what the resulting ctc_topo looks like.
    '''
    symbols = k2.SymbolTable.from_str('''
        <blk> 0
        a 1
        b 2
    ''')
    aux_symbols = k2.SymbolTable.from_str('''
        a 1
        b 2
    ''')
    ctc_topo = k2.ctc_topo(2)
    ctc_topo.labels_sym = symbols
    ctc_topo.aux_labels_sym = aux_symbols
    ctc_topo.draw('ctc_topo.pdf')
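Fsa.draw renders through graphviz, so that package must be installed; if it is not available, the topology can still be inspected in k2's plain-text arc format (a small sketch, not part of the original function):

import k2

ctc_topo = k2.ctc_topo(2)
# Each printed line is roughly: src_state dst_state label aux_label score
print(ctc_topo)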
Example No. 8
    def test_random_case2(self):
        # 2 sequences
        for device in self.devices:
            T1 = torch.randint(10, 200, (1, )).item()
            T2 = torch.randint(9, 100, (1, )).item()
            C = torch.randint(20, 30, (1, )).item()
            if T1 < T2:
                T1, T2 = T2, T1

            torch_activation_1 = torch.rand((T1, C),
                                            dtype=torch.float32,
                                            device=device).requires_grad_(True)
            torch_activation_2 = torch.rand((T2, C),
                                            dtype=torch.float32,
                                            device=device).requires_grad_(True)

            k2_activation_1 = torch_activation_1.detach().clone(
            ).requires_grad_(True)
            k2_activation_2 = torch_activation_2.detach().clone(
            ).requires_grad_(True)

            # [T, N, C]
            torch_activations = torch.nn.utils.rnn.pad_sequence(
                [torch_activation_1, torch_activation_2],
                batch_first=False,
                padding_value=0)

            # [N, T, C]
            k2_activations = torch.nn.utils.rnn.pad_sequence(
                [k2_activation_1, k2_activation_2],
                batch_first=True,
                padding_value=0)

            target_length1 = torch.randint(1, T1, (1, )).item()
            target_length2 = torch.randint(1, T2, (1, )).item()

            target_lengths = torch.tensor([target_length1,
                                           target_length2]).to(device)
            targets = torch.randint(1, C - 1,
                                    (target_lengths.sum(), )).to(device)

            # [T, N, C]
            torch_log_probs = torch.nn.functional.log_softmax(
                torch_activations, dim=-1)
            input_lengths = torch.tensor([T1, T2]).to(device)

            torch_loss = torch.nn.functional.ctc_loss(
                log_probs=torch_log_probs,
                targets=targets,
                input_lengths=input_lengths,
                target_lengths=target_lengths,
                reduction='mean')

            assert T1 >= T2
            supervision_segments = torch.tensor([[0, 0, T1], [1, 0, T2]],
                                                dtype=torch.int32)
            k2_log_probs = torch.nn.functional.log_softmax(k2_activations,
                                                           dim=-1)
            dense_fsa_vec = k2.DenseFsaVec(k2_log_probs,
                                           supervision_segments).to(device)
            ctc_topo = k2.ctc_topo(C - 1)
            linear_fsa = k2.linear_fsa([
                targets[:target_length1].tolist(),
                targets[target_length1:].tolist()
            ])
            decoding_graph = k2.compose(ctc_topo, linear_fsa).to(device)

            k2_loss = k2.ctc_loss(decoding_graph,
                                  dense_fsa_vec,
                                  reduction='mean',
                                  target_lengths=target_lengths)

            assert torch.allclose(torch_loss, k2_loss)
            scale = torch.rand_like(torch_loss) * 100
            (torch_loss * scale).sum().backward()
            (k2_loss * scale).sum().backward()
            assert torch.allclose(torch_activation_1.grad,
                                  k2_activation_1.grad,
                                  atol=1e-2)
            assert torch.allclose(torch_activation_2.grad,
                                  k2_activation_2.grad,
                                  atol=1e-2)
Example No. 9
    def test_case4(self):
        for device in self.devices:
            # put case3, case2 and case1 into a batch
            torch_activation_1 = torch.tensor(
                [[0., 0., 0., 0., 0.]]).to(device).requires_grad_(True)

            torch_activation_2 = torch.arange(1, 16).reshape(3, 5).to(
                torch.float32).to(device).requires_grad_(True)

            torch_activation_3 = torch.tensor([
                [-5, -4, -3, -2, -1],
                [-10, -9, -8, -7, -6],
                [-15, -14, -13, -12, -11.],
            ]).to(device).requires_grad_(True)

            k2_activation_1 = torch_activation_1.detach().clone(
            ).requires_grad_(True)
            k2_activation_2 = torch_activation_2.detach().clone(
            ).requires_grad_(True)
            k2_activation_3 = torch_activation_3.detach().clone(
            ).requires_grad_(True)

            # [T, N, C]
            torch_activations = torch.nn.utils.rnn.pad_sequence(
                [torch_activation_3, torch_activation_2, torch_activation_1],
                batch_first=False,
                padding_value=0)

            # [N, T, C]
            k2_activations = torch.nn.utils.rnn.pad_sequence(
                [k2_activation_3, k2_activation_2, k2_activation_1],
                batch_first=True,
                padding_value=0)

            # [[b,c], [c,c], [a]]
            targets = torch.tensor([2, 3, 3, 3, 1]).to(device)
            input_lengths = torch.tensor([3, 3, 1]).to(device)
            target_lengths = torch.tensor([2, 2, 1]).to(device)

            torch_log_probs = torch.nn.functional.log_softmax(
                torch_activations, dim=-1)  # (T, N, C)

            torch_loss = torch.nn.functional.ctc_loss(
                log_probs=torch_log_probs,
                targets=targets,
                input_lengths=input_lengths,
                target_lengths=target_lengths,
                reduction='sum')

            expected_loss = torch.tensor(
                [4.938850402832, 7.355742931366, 1.6094379425049]).sum()

            assert torch.allclose(torch_loss, expected_loss.to(device))

            k2_log_probs = torch.nn.functional.log_softmax(k2_activations,
                                                           dim=-1)
            supervision_segments = torch.tensor(
                [[0, 0, 3], [1, 0, 3], [2, 0, 1]], dtype=torch.int32)
            dense_fsa_vec = k2.DenseFsaVec(k2_log_probs,
                                           supervision_segments).to(device)

            ctc_topo = k2.ctc_topo(4)
            # [ [b, c], [c, c], [a]]
            linear_fsa = k2.linear_fsa([[2, 3], [3, 3], [1]])
            decoding_graph = k2.compose(ctc_topo, linear_fsa).to(device)

            k2_loss = k2.ctc_loss(decoding_graph,
                                  dense_fsa_vec,
                                  reduction='sum',
                                  target_lengths=target_lengths)

            assert torch.allclose(torch_loss, k2_loss)

            scale = torch.tensor([1., -2, 3.5]).to(device)
            (torch_loss * scale).sum().backward()
            (k2_loss * scale).sum().backward()
            assert torch.allclose(torch_activation_1.grad,
                                  k2_activation_1.grad)
            assert torch.allclose(torch_activation_2.grad,
                                  k2_activation_2.grad)
            assert torch.allclose(torch_activation_3.grad,
                                  k2_activation_3.grad)
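Of the three expected losses, 7.355742931366 (the [c, c] sequence) is the easiest to verify by hand: the standard CTC topology requires a blank between repeated labels, so with only three frames the single valid alignment is c <blk> c. A brute-force sketch (not part of the original test), using the same activations as torch_activation_2:

import torch

act = torch.arange(1, 16).reshape(3, 5).to(torch.float32)
log_probs = act.log_softmax(dim=-1)  # (T=3, C=5)
# The only length-3 alignment collapsing to [c, c] = [3, 3] is [3, 0, 3].
loss = -(log_probs[0, 3] + log_probs[1, 0] + log_probs[2, 3])
print(loss)  # ~7.3557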
Example No. 10
    def visualize_ctc_topo():
        '''This function shows how to visualize
        standard/modified ctc topologies. It's for
        demonstration only, not for testing.
        '''
        max_token = 2
        labels_sym = k2.SymbolTable.from_str('''
            <blk> 0
            z 1
            o 2
        ''')
        aux_labels_sym = k2.SymbolTable.from_str('''
            z 1
            o 2
        ''')

        word_sym = k2.SymbolTable.from_str('''
            zoo 1
        ''')

        standard = k2.ctc_topo(max_token, modified=False)
        modified = k2.ctc_topo(max_token, modified=True)
        standard.labels_sym = labels_sym
        standard.aux_labels_sym = aux_labels_sym

        modified.labels_sym = labels_sym
        modified.aux_labels_sym = aux_labels_sym

        standard.draw('standard_topo.svg', title='standard CTC topo')
        modified.draw('modified_topo.svg', title='modified CTC topo')
        fsa = k2.linear_fst([1, 2, 2], [1, 0, 0])
        fsa.labels_sym = labels_sym
        fsa.aux_labels_sym = word_sym
        fsa.draw('transcript.svg', title='transcript')

        standard_graph = k2.compose(standard, fsa)
        modified_graph = k2.compose(modified, fsa)
        standard_graph.draw('standard_graph.svg', title='standard graph')
        modified_graph.draw('modified_graph.svg', title='modified graph')

        # z z <blk> <blk> o o <blk> o <blk>
        inputs = k2.linear_fsa([1, 1, 0, 0, 2, 2, 0, 2, 0])
        inputs.labels_sym = labels_sym
        inputs.draw('inputs.svg', title='inputs')
        standard_lattice = k2.intersect(standard_graph,
                                        inputs,
                                        treat_epsilons_specially=False)
        standard_lattice.draw('standard_lattice.svg', title='standard lattice')

        modified_lattice = k2.intersect(modified_graph,
                                        inputs,
                                        treat_epsilons_specially=False)
        modified_lattice = k2.connect(modified_lattice)
        modified_lattice.draw('modified_lattice.svg', title='modified lattice')

        # z z <blk> <blk> o o o <blk>
        inputs2 = k2.linear_fsa([1, 1, 0, 0, 2, 2, 2, 0])
        inputs2.labels_sym = labels_sym
        inputs2.draw('inputs2.svg', title='inputs2')
        standard_lattice2 = k2.intersect(standard_graph,
                                         inputs2,
                                         treat_epsilons_specially=False)
        standard_lattice2 = k2.connect(standard_lattice2)
        # It's empty since the topo requires that there must be a blank
        # between the two o's in zoo
        assert standard_lattice2.num_arcs == 0
        standard_lattice2.draw('standard_lattice2.svg',
                               title='standard lattice2')

        modified_lattice2 = k2.intersect(modified_graph,
                                         inputs2,
                                         treat_epsilons_specially=False)
        modified_lattice2 = k2.connect(modified_lattice2)
        modified_lattice2.draw('modified_lattice2.svg',
                               title='modified lattice2')