Example 1
def _compute_mmi_loss_exact_non_optimized(
        nnet_output: torch.Tensor,
        texts: List[str],
        supervision_segments: torch.Tensor,
        graph_compiler: MmiTrainingGraphCompiler,
        den_scale: float = 1.0
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    '''
    See :func:`_compute_mmi_loss_exact_optimized` for the meaning
    of the arguments.

    This version is more readable, though it invokes k2.intersect_dense twice.

    Note:
      It uses less memory than the optimized version, at the cost of speed.
    '''
    num_graphs, den_graphs = graph_compiler.compile(texts, replicate_den=True)

    dense_fsa_vec = k2.DenseFsaVec(nnet_output, supervision_segments)

    num_lats = k2.intersect_dense(num_graphs, dense_fsa_vec, output_beam=10.0)
    den_lats = k2.intersect_dense(den_graphs, dense_fsa_vec, output_beam=10.0)

    num_tot_scores = num_lats.get_tot_scores(log_semiring=True,
                                             use_double_scores=True)

    den_tot_scores = den_lats.get_tot_scores(log_semiring=True,
                                             use_double_scores=True)

    tot_scores = num_tot_scores - den_scale * den_tot_scores
    tot_score, tot_frames, all_frames = get_tot_objf_and_num_frames(
        tot_scores, supervision_segments[:, 2])
    return tot_score, tot_frames, all_frames
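For reference, each row of supervision_segments throughout these examples is [sequence_index, start_frame, num_frames] with dtype torch.int32. A minimal standalone sketch (the lengths 95 and 75 are assumptions chosen for illustration):

import torch

# Each row: [sequence_index, start_frame, num_frames].
supervision_segments = torch.tensor([[0, 0, 95],
                                     [1, 0, 75]], dtype=torch.int32)

# Several examples below order the rows by decreasing num_frames before
# building a k2.DenseFsaVec, using torch.argsort(..., descending=True):
indices = torch.argsort(supervision_segments[:, 2], descending=True)
supervision_segments = supervision_segments[indices]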
Example 2
    def forward(self, log_probs: torch.Tensor, targets: torch.Tensor,
                input_lengths: torch.Tensor,
                target_lengths: torch.Tensor) -> torch.Tensor:

        # [T, N, C] -> [N, T, C], i.e., batch_size x seq_length x alphabet_size
        log_probs = log_probs.permute(1, 0, 2).cpu()
        supervision_segments = torch.stack(
            (torch.arange(input_lengths.shape[0]),
             torch.zeros(input_lengths.shape[0]), input_lengths),
            1).to(torch.int32)
        indices = torch.argsort(supervision_segments[:, 2], descending=True)
        supervision_segments = supervision_segments[indices]

        dense_fsa_vec = k2.DenseFsaVec(log_probs, supervision_segments)
        decoding_graph = self.graph_compiler.compile(targets.cpu(),
                                                     target_lengths)
        decoding_graph = k2.index(decoding_graph,
                                  indices.to(torch.int32)).to(log_probs.device)

        target_graph = k2.intersect_dense(decoding_graph, dense_fsa_vec, 10.0)
        tot_scores = k2.get_tot_scores(target_graph,
                                       log_semiring=True,
                                       use_double_scores=True)
        (tot_score, tot_frames,
         all_frames) = get_tot_objf_and_num_frames(tot_scores,
                                                   supervision_segments[:, 2])
        return -tot_score
Example 3
    def test_two_fsas(self):
        s1 = '''
            0 1 1 1.0
            1 1 1 50.0
            1 2 2 2.0
            2 3 -1 3.0
            3
        '''

        s2 = '''
            0 1 1 1.0
            1 2 2 2.0
            2 3 -1 3.0
            3
        '''

        fsa1 = k2.Fsa.from_str(s1)
        fsa2 = k2.Fsa.from_str(s2)

        fsa1.requires_grad_(True)
        fsa2.requires_grad_(True)

        fsa_vec = k2.create_fsa_vec([fsa1, fsa2])

        log_prob = torch.tensor(
            [[[0.1, 0.2, 0.3], [0.04, 0.05, 0.06], [0.0, 0.0, 0.0]],
             [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.0, 0.0, 0.0]]],
            dtype=torch.float32,
            requires_grad=True)

        supervision_segments = torch.tensor([[0, 0, 3], [1, 0, 2]],
                                            dtype=torch.int32)
        dense_fsa_vec = k2.DenseFsaVec(log_prob, supervision_segments)
        out_fsa = k2.intersect_dense(fsa_vec,
                                     dense_fsa_vec,
                                     output_beam=100000)
        assert out_fsa.shape == (2, None, None), 'There should be two FSAs!'

        scores = k2.get_tot_scores(out_fsa,
                                   log_semiring=False,
                                   use_float_scores=True)
        scores.sum().backward()

        # `expected` results are computed using gtn.
        # See https://bit.ly/3oYObeb
        #  expected_scores_out_fsa = torch.tensor(
        #      [1.2, 2.06, 3.0, 1.2, 50.5, 2.0, 3.0])

        expected_grad_fsa1 = torch.tensor([1.0, 1.0, 1.0, 1.0])
        expected_grad_fsa2 = torch.tensor([1.0, 1.0, 1.0])
        print("fsa2 is ", fsa2.__str__())
        # TODO(dan):: fix this..
        #  expected_grad_log_prob = torch.tensor([
        #      0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0, 0, 0, 0.0, 1.0, 0.0, 0.0, 1.0,
        #      0.0, 0.0, 0.0, 1.0
        #  ]).reshape_as(log_prob)

        # assert torch.allclose(out_fsa.scores, expected_scores_out_fsa)
        assert torch.allclose(expected_grad_fsa1, fsa1.scores.grad)
        assert torch.allclose(expected_grad_fsa2, fsa2.scores.grad)
Example 4
    def _intersect_calc_scores_mmi_pruned(
        self, dense_fsa_vec: k2.DenseFsaVec, num_graphs: 'k2.Fsa', den_graph: 'k2.Fsa', return_lats: bool = True,
    ):
        device = dense_fsa_vec.device
        assert device == num_graphs.device and device == den_graph.device

        num_fsas = num_graphs.shape[0]
        assert dense_fsa_vec.dim0() == num_fsas

        num_lats = k2.intersect_dense(
            a_fsas=num_graphs,
            b_fsas=dense_fsa_vec,
            output_beam=self.intersect_conf.output_beam,
            seqframe_idx_name="seqframe_idx" if return_lats else None,
        )
        den_lats = k2.intersect_dense_pruned(
            a_fsas=den_graph,
            b_fsas=dense_fsa_vec,
            search_beam=self.intersect_conf.search_beam,
            output_beam=self.intersect_conf.output_beam,
            min_active_states=self.intersect_conf.min_active_states,
            max_active_states=self.intersect_conf.max_active_states,
            seqframe_idx_name="seqframe_idx" if return_lats else None,
        )

        # use_double_scores=True matters here: with single precision,
        # the total scores sometimes suffer from rounding errors.
        num_tot_scores = num_lats.get_tot_scores(log_semiring=True, use_double_scores=True)
        den_tot_scores = den_lats.get_tot_scores(log_semiring=True, use_double_scores=True)

        if return_lats:
            return num_tot_scores, den_tot_scores, num_lats, den_lats
        else:
            return num_tot_scores, den_tot_scores, None, None
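The helper above reads its beams and active-state bounds from self.intersect_conf. A hypothetical container for those fields (the class name is an assumption; the defaults mirror the values used in Example 16):

from dataclasses import dataclass

@dataclass
class IntersectConfig:
    # Fields read by _intersect_calc_scores_mmi_pruned above.
    search_beam: float = 20.0
    output_beam: float = 7.0
    min_active_states: int = 30
    max_active_states: int = 10000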
Example 5
    def test_case1(self):
        devices = [torch.device('cpu')]
        if torch.cuda.is_available():
            devices.append(torch.device('cuda'))

        for device in devices:
            # suppose we have four symbols: <blk>, a, b, c, d
            torch_activation = torch.tensor([0.2, 0.2, 0.2, 0.2,
                                             0.2]).to(device)
            k2_activation = torch_activation.detach().clone()

            # (T, N, C)
            torch_activation = torch_activation.reshape(
                1, 1, -1).requires_grad_(True)

            # (N, T, C)
            k2_activation = k2_activation.reshape(1, 1,
                                                  -1).requires_grad_(True)

            torch_log_probs = torch.nn.functional.log_softmax(
                torch_activation, dim=-1)  # (T, N, C)

            # we have only one sequence and its label is `a`
            targets = torch.tensor([1]).to(device)
            input_lengths = torch.tensor([1]).to(device)
            target_lengths = torch.tensor([1]).to(device)
            torch_loss = torch.nn.functional.ctc_loss(
                log_probs=torch_log_probs,
                targets=targets,
                input_lengths=input_lengths,
                target_lengths=target_lengths,
                reduction='none')

            assert torch.allclose(torch_loss,
                                  torch.tensor([1.6094379425049]).to(device))

            # (N, T, C)
            k2_log_probs = torch.nn.functional.log_softmax(k2_activation,
                                                           dim=-1)

            supervision_segments = torch.tensor([[0, 0, 1]], dtype=torch.int32)
            dense_fsa_vec = k2.DenseFsaVec(k2_log_probs,
                                           supervision_segments).to(device)

            ctc_topo_inv = k2.arc_sort(
                build_ctc_topo([0, 1, 2, 3, 4]).invert_())
            linear_fsa = k2.linear_fsa([1])
            decoding_graph = k2.intersect(ctc_topo_inv, linear_fsa)
            decoding_graph = k2.connect(decoding_graph).invert_().to(device)

            target_graph = k2.intersect_dense(decoding_graph, dense_fsa_vec,
                                              100.0)

            k2_scores = target_graph.get_tot_scores(log_semiring=True,
                                                    use_double_scores=False)
            assert torch.allclose(torch_loss, -1 * k2_scores)

            torch_loss.backward()
            (-k2_scores).backward()
            assert torch.allclose(torch_activation.grad, k2_activation.grad)
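The reference value 1.6094379425049 is just -log(0.2) in float32: with uniform activations over 5 symbols, log_softmax assigns each symbol probability 0.2, so the loss of the single frame labelled `a` is -log(0.2) = log(5):

import math

assert abs(-math.log(0.2) - 1.6094379425049) < 1e-6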
Example 6
    def test_simple(self):
        s = '''
            0 1 1 1.0
            1 1 1 50.0
            1 2 2 2.0
            2 3 -1 3.0
            3
        '''
        fsa = k2.Fsa.from_str(s)
        fsa.requires_grad_(True)
        fsa_vec = k2.create_fsa_vec([fsa])
        log_prob = torch.tensor([[[0.1, 0.2, 0.3], [0.04, 0.05, 0.06]]],
                                dtype=torch.float32,
                                requires_grad=True)

        supervision_segments = torch.tensor([[0, 0, 2]], dtype=torch.int32)
        dense_fsa_vec = k2.DenseFsaVec(log_prob, supervision_segments)
        out_fsa = k2.intersect_dense(fsa_vec,
                                     dense_fsa_vec,
                                     output_beam=100000)
        scores = k2.get_tot_scores(out_fsa,
                                   log_semiring=False,
                                   use_float_scores=True)
        scores.sum().backward()

        # `expected` results are computed using gtn.
        # See https://bit.ly/3oYObeb
        expected_scores_out_fsa = torch.tensor([1.2, 2.06, 3.0])
        expected_grad_fsa = torch.tensor([1.0, 0.0, 1.0, 1.0])
        expected_grad_log_prob = torch.tensor([0.0, 1.0, 0.0, 0.0, 0.0,
                                               1.0]).reshape_as(log_prob)
        assert torch.allclose(out_fsa.scores, expected_scores_out_fsa)
        assert torch.allclose(expected_grad_fsa, fsa.scores.grad)
        assert torch.allclose(expected_grad_log_prob, log_prob.grad)
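The expected lattice scores above can be verified by hand: each arc of the output FSA carries its graph arc score plus the dense log-prob of its label at the corresponding frame (final arcs see an extra all-zero frame appended by DenseFsaVec):

# Arc 0->1, label 1, frame 0: 1.0 + log_prob[0][0][1] = 1.0 + 0.2  = 1.2
# Arc 1->2, label 2, frame 1: 2.0 + log_prob[0][1][2] = 2.0 + 0.06 = 2.06
# Arc 2->3, label -1 (final): 3.0 + 0.0               = 3.0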
Example 7
    def test_case3(self):
        devices = [torch.device('cpu')]
        if torch.cuda.is_available():
            devices.append(torch.device('cuda'))

        for device in devices:
            # (T, N, C)
            torch_activation = torch.tensor([[
                [-5, -4, -3, -2, -1],
                [-10, -9, -8, -7, -6],
                [-15, -14, -13, -12, -11.],
            ]]).permute(1, 0, 2).to(device).requires_grad_(True)
            torch_activation = torch_activation.to(torch.float32)
            torch_activation.requires_grad_(True)

            k2_activation = torch_activation.detach().clone().requires_grad_(
                True)

            torch_log_probs = torch.nn.functional.log_softmax(
                torch_activation, dim=-1)  # (T, N, C)
            # we have only one sequence and its labels are `b,c`
            targets = torch.tensor([2, 3]).to(device)
            input_lengths = torch.tensor([3]).to(device)
            target_lengths = torch.tensor([2]).to(device)

            torch_loss = torch.nn.functional.ctc_loss(
                log_probs=torch_log_probs,
                targets=targets,
                input_lengths=input_lengths,
                target_lengths=target_lengths,
                reduction='none')

            act = k2_activation.permute(1, 0, 2)  # (T, N, C) -> (N, T, C)
            k2_log_probs = torch.nn.functional.log_softmax(act, dim=-1)

            supervision_segments = torch.tensor([[0, 0, 3]], dtype=torch.int32)
            dense_fsa_vec = k2.DenseFsaVec(k2_log_probs,
                                           supervision_segments).to(device)

            ctc_topo_inv = k2.arc_sort(
                build_ctc_topo([0, 1, 2, 3, 4]).invert_())
            linear_fsa = k2.linear_fsa([2, 3])
            decoding_graph = k2.intersect(ctc_topo_inv, linear_fsa)
            decoding_graph = k2.connect(decoding_graph).invert_().to(device)

            target_graph = k2.intersect_dense(decoding_graph, dense_fsa_vec,
                                              100.0)

            k2_scores = target_graph.get_tot_scores(log_semiring=True,
                                                    use_double_scores=False)
            assert torch.allclose(torch_loss, -1 * k2_scores)
            assert torch.allclose(torch_loss,
                                  torch.tensor([4.938850402832]).to(device))

            torch_loss.backward()
            (-k2_scores).backward()
            assert torch.allclose(torch_activation.grad, k2_activation.grad)
Example 8
    def _intersect_calc_scores_mmi_exact(
        self, dense_fsa_vec: k2.DenseFsaVec, num_graphs: 'k2.Fsa', den_graph: 'k2.Fsa', return_lats: bool = True,
    ):
        device = dense_fsa_vec.device
        assert device == num_graphs.device and device == den_graph.device

        num_fsas = num_graphs.shape[0]
        assert dense_fsa_vec.dim0() == num_fsas

        den_graph = den_graph.clone()
        num_graphs = num_graphs.clone()

        num_den_graphs = k2.cat([num_graphs, den_graph])

        # NOTE: The a_to_b_map in k2.intersect_dense must be sorted
        # so the following reorders num_den_graphs.

        # [0, 1, 2, ... ]
        num_graphs_indexes = torch.arange(num_fsas, dtype=torch.int32)

        # [num_fsas, num_fsas, num_fsas, ... ]
        den_graph_indexes = torch.tensor([num_fsas] * num_fsas, dtype=torch.int32)

        # [0, num_fsas, 1, num_fsas, 2, num_fsas, ... ]
        num_den_graphs_indexes = torch.stack([num_graphs_indexes, den_graph_indexes]).t().reshape(-1).to(device)

        num_den_reordered_graphs = k2.index_fsa(num_den_graphs, num_den_graphs_indexes)

        # [[0, 1, 2, ...]]
        a_to_b_map = torch.arange(num_fsas, dtype=torch.int32).reshape(1, -1)

        # [[0, 1, 2, ...]] -> [0, 0, 1, 1, 2, 2, ... ]
        a_to_b_map = a_to_b_map.repeat(2, 1).t().reshape(-1).to(device)

        num_den_lats = k2.intersect_dense(
            a_fsas=num_den_reordered_graphs,
            b_fsas=dense_fsa_vec,
            output_beam=self.intersect_conf.output_beam,
            a_to_b_map=a_to_b_map,
            seqframe_idx_name="seqframe_idx" if return_lats else None,
        )

        num_den_tot_scores = num_den_lats.get_tot_scores(log_semiring=True, use_double_scores=True)
        num_tot_scores = num_den_tot_scores[::2]
        den_tot_scores = num_den_tot_scores[1::2]

        if return_lats:
            lat_slice = torch.arange(num_fsas, dtype=torch.int32).to(device) * 2
            return (
                num_tot_scores,
                den_tot_scores,
                k2.index_fsa(num_den_lats, lat_slice),
                k2.index_fsa(num_den_lats, lat_slice + 1),
            )
        else:
            return num_tot_scores, den_tot_scores, None, None
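To make the reordering above concrete, here is a self-contained sketch of the index bookkeeping, with an assumed num_fsas = 3 and no k2 objects involved:

import torch

num_fsas = 3  # assumed for illustration

num_graphs_indexes = torch.arange(num_fsas, dtype=torch.int32)  # [0, 1, 2]
den_graph_indexes = torch.tensor([num_fsas] * num_fsas,
                                 dtype=torch.int32)              # [3, 3, 3]

# Interleave: numerator graph i is followed by the shared denominator graph.
num_den_graphs_indexes = torch.stack(
    [num_graphs_indexes, den_graph_indexes]).t().reshape(-1)
print(num_den_graphs_indexes.tolist())  # [0, 3, 1, 3, 2, 3]

# Each consecutive (num, den) pair maps to the same dense FSA, so the
# resulting a_to_b_map is sorted, as k2.intersect_dense requires.
a_to_b_map = torch.arange(num_fsas, dtype=torch.int32).reshape(1, -1)
a_to_b_map = a_to_b_map.repeat(2, 1).t().reshape(-1)
print(a_to_b_map.tolist())  # [0, 0, 1, 1, 2, 2]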
Example 9
    def test_random_case1(self):
        # 1 sequence
        devices = [torch.device('cpu')]
        if torch.cuda.is_available():
            devices.append(torch.device('cuda', 0))

        for device in devices:
            T = torch.randint(10, 100, (1,)).item()
            C = torch.randint(20, 30, (1,)).item()
            torch_activation = torch.rand((1, T + 10, C),
                                          dtype=torch.float32,
                                          device=device).requires_grad_(True)

            k2_activation = torch_activation.detach().clone().requires_grad_(
                True)

            # [N, T, C] -> [T, N, C]
            torch_log_probs = torch.nn.functional.log_softmax(
                torch_activation.permute(1, 0, 2), dim=-1)

            input_lengths = torch.tensor([T]).to(device)
            target_lengths = torch.randint(1, T, (1,)).to(device)
            targets = torch.randint(1, C - 1,
                                    (target_lengths.item(),)).to(device)

            torch_loss = torch.nn.functional.ctc_loss(
                log_probs=torch_log_probs,
                targets=targets,
                input_lengths=input_lengths,
                target_lengths=target_lengths,
                reduction='none')
            k2_log_probs = torch.nn.functional.log_softmax(k2_activation,
                                                           dim=-1)
            supervision_segments = torch.tensor([[0, 0, T]], dtype=torch.int32)
            dense_fsa_vec = k2.DenseFsaVec(k2_log_probs,
                                           supervision_segments).to(device)
            ctc_topo_inv = k2.arc_sort(
                build_ctc_topo(list(range(C))).invert_())
            linear_fsa = k2.linear_fsa([targets.tolist()])

            decoding_graph = k2.intersect(ctc_topo_inv, linear_fsa)
            decoding_graph = k2.connect(decoding_graph).invert_().to(device)

            target_graph = k2.intersect_dense(decoding_graph, dense_fsa_vec,
                                              100.0)

            k2_scores = target_graph.get_tot_scores(log_semiring=True,
                                                    use_double_scores=False)
            assert torch.allclose(torch_loss, -1 * k2_scores)
            scale = torch.rand_like(torch_loss) * 100
            (torch_loss * scale).sum().backward()
            (-k2_scores * scale).sum().backward()
            assert torch.allclose(torch_activation.grad,
                                  k2_activation.grad,
                                  atol=1e-2)
Example 10
    def test_two_fsas_long(self):
        # Same as test_two_fsas, but generates a long DenseFsaVec for easier
        # profiling.
        s1 = '''
            0 1 1 1.0
            1 1 1 50.0
            1 2 2 2.0
            2 3 -1 3.0
            3
        '''

        s2 = '''
            0 1 1 1.0
            1 2 2 2.0
            2 3 -1 3.0
            3
        '''

        devices = [torch.device('cpu')]
        if torch.cuda.is_available():
            devices.append(torch.device('cuda', 0))
        for device in devices:
            fsa1 = k2.Fsa.from_str(s1)
            fsa2 = k2.Fsa.from_str(s2)

            fsa1.requires_grad_(True)
            fsa2.requires_grad_(True)

            fsa_vec = k2.create_fsa_vec([fsa1, fsa2])
            log_prob = torch.rand((2, 100, 3),
                                  dtype=torch.float32,
                                  device=device,
                                  requires_grad=True)

            supervision_segments = torch.tensor([[0, 1, 95], [1, 20, 50]],
                                                dtype=torch.int32)
            dense_fsa_vec = k2.DenseFsaVec(log_prob, supervision_segments)
            fsa_vec = fsa_vec.to(device)
            out_fsa = k2.intersect_dense(fsa_vec,
                                         dense_fsa_vec,
                                         output_beam=100000,
                                         seqframe_idx_name='seqframe',
                                         frame_idx_name='frame')
            expected_seqframe = torch.arange(96).to(torch.int32).to(device)
            assert torch.allclose(out_fsa.seqframe, expected_seqframe)

            # the second output FSA is empty since there is no self-loop in fsa2
            assert torch.allclose(out_fsa.frame, expected_seqframe)

            assert out_fsa.shape == (2, None,
                                     None), 'There should be two FSAs!'

            scores = out_fsa.get_tot_scores(log_semiring=False,
                                            use_double_scores=False)
            scores.sum().backward()
Example 11
    def test_case2(self):
        for device in self.devices:
            # (T, N, C)
            torch_activation = torch.arange(1, 16).reshape(1, 3, 5).permute(
                1, 0, 2).to(device)
            torch_activation = torch_activation.to(torch.float32)
            torch_activation.requires_grad_(True)

            k2_activation = torch_activation.detach().clone().requires_grad_(
                True)

            torch_log_probs = torch.nn.functional.log_softmax(
                torch_activation, dim=-1)  # (T, N, C)
            # we have only one sequence and its labels are `c,c`
            targets = torch.tensor([3, 3]).to(device)
            input_lengths = torch.tensor([3]).to(device)
            target_lengths = torch.tensor([2]).to(device)

            torch_loss = torch.nn.functional.ctc_loss(
                log_probs=torch_log_probs,
                targets=targets,
                input_lengths=input_lengths,
                target_lengths=target_lengths,
                reduction='none')

            act = k2_activation.permute(1, 0, 2)  # (T, N, C) -> (N, T, C)
            k2_log_probs = torch.nn.functional.log_softmax(act, dim=-1)

            supervision_segments = torch.tensor([[0, 0, 3]], dtype=torch.int32)
            dense_fsa_vec = k2.DenseFsaVec(k2_log_probs,
                                           supervision_segments).to(device)

            ctc_topo_inv = k2.arc_sort(
                build_ctc_topo([0, 1, 2, 3, 4]).invert_())
            linear_fsa = k2.linear_fsa([3, 3])
            decoding_graph = k2.intersect(ctc_topo_inv, linear_fsa)
            decoding_graph = k2.connect(decoding_graph).invert_().to(device)

            target_graph = k2.intersect_dense(decoding_graph, dense_fsa_vec,
                                              100.0)

            k2_scores = target_graph.get_tot_scores(log_semiring=True,
                                                    use_double_scores=False)
            assert torch.allclose(torch_loss, -1 * k2_scores)
            assert torch.allclose(torch_loss,
                                  torch.tensor([7.355742931366]).to(device))

            torch_loss.backward()
            (-k2_scores).backward()
            assert torch.allclose(torch_activation.grad, k2_activation.grad)
Example 12
    def test_simple(self):
        s = '''
            0 1 1 1.0
            1 1 1 50.0
            1 2 2 2.0
            2 3 -1 3.0
            3
        '''
        for device in self.devices:
            fsa = k2.Fsa.from_str(s).to(device)
            fsa.requires_grad_(True)
            fsa_vec = k2.create_fsa_vec([fsa])
            log_prob = torch.tensor([[[0.1, 0.2, 0.3], [0.04, 0.05, 0.06]]],
                                    dtype=torch.float32,
                                    device=device,
                                    requires_grad=True)

            supervision_segments = torch.tensor([[0, 0, 2]], dtype=torch.int32)
            dense_fsa_vec = k2.DenseFsaVec(log_prob, supervision_segments)
            out_fsa = k2.intersect_dense(fsa_vec,
                                         dense_fsa_vec,
                                         output_beam=100000,
                                         seqframe_idx_name='seqframe',
                                         frame_idx_name='frame')
            assert torch.all(
                torch.eq(out_fsa.seqframe,
                         torch.tensor([0, 1, 2], device=device)))

            assert torch.all(
                torch.eq(out_fsa.frame, torch.tensor([0, 1, 2],
                                                     device=device)))

            scores = out_fsa.get_tot_scores(log_semiring=False,
                                            use_double_scores=False)

            scores.sum().backward()

            # `expected` results are computed using gtn.
            # See https://colab.research.google.com/drive/1FzEFjj5GoCDN2d05D9jE682CkR7QIlnm?usp=sharing
            expected_scores_out_fsa = torch.tensor([1.2, 2.06, 3.0],
                                                   device=device)
            expected_grad_fsa = torch.tensor([1.0, 0.0, 1.0, 1.0],
                                             device=device)
            expected_grad_log_prob = torch.tensor(
                [0.0, 1.0, 0.0, 0.0, 0.0, 1.0],
                device=device).reshape_as(log_prob)
            assert torch.allclose(out_fsa.scores, expected_scores_out_fsa)
            assert torch.allclose(expected_grad_fsa, fsa.scores.grad)
            assert torch.allclose(expected_grad_log_prob, log_prob.grad)
Example 13
    def forward(
            self, nnet_output: torch.Tensor, texts: List,
            supervision_segments: torch.Tensor
    ) -> Tuple[torch.Tensor, int, int]:
        num_graphs = self.graph_compiler.compile(texts).to(nnet_output.device)
        dense_fsa_vec = k2.DenseFsaVec(nnet_output, supervision_segments)

        num_lats = k2.intersect_dense(num_graphs, dense_fsa_vec, 10.0)

        num_tot_scores = num_lats.get_tot_scores(log_semiring=True,
                                                 use_double_scores=True)
        tot_scores = num_tot_scores
        tot_score, tot_frames, all_frames = get_tot_objf_and_num_frames(
            tot_scores, supervision_segments[:, 2])
        return tot_score, tot_frames, all_frames
Example 14
    def test_two_fsas_long(self):
        # Same as test_two_fsas, but generates a long DenseFsaVec for easier
        # profiling.
        s1 = '''
            0 1 1 1.0
            1 1 1 50.0
            1 2 2 2.0
            2 3 -1 3.0
            3
        '''

        s2 = '''
            0 1 1 1.0
            1 2 2 2.0
            2 3 -1 3.0
            3
        '''

        devices = [torch.device('cpu')]
        if torch.cuda.is_available():
            devices.append(torch.device('cuda', 0))
        for device in devices:
            fsa1 = k2.Fsa.from_str(s1)
            fsa2 = k2.Fsa.from_str(s2)

            fsa1.requires_grad_(True)
            fsa2.requires_grad_(True)

            fsa_vec = k2.create_fsa_vec([fsa1, fsa2])
            log_prob = torch.rand((2, 500, 3),
                                  dtype=torch.float32,
                                  device=device,
                                  requires_grad=True)

            supervision_segments = torch.tensor([[0, 0, 490], [1, 0, 300]],
                                                dtype=torch.int32)
            dense_fsa_vec = k2.DenseFsaVec(log_prob, supervision_segments)
            fsa_vec = fsa_vec.to(device)
            out_fsa = k2.intersect_dense(fsa_vec,
                                         dense_fsa_vec,
                                         output_beam=100000)
            assert out_fsa.shape == (2, None,
                                     None), 'There should be two FSAs!'

            scores = k2.get_tot_scores(out_fsa,
                                       log_semiring=False,
                                       use_float_scores=True)
            scores.sum().backward()
Example 15
    def align(self, cuts: Union[AnyCut, CutSet]) -> torch.Tensor:
        """
        Perform forced alignment and return a tensor that represents a batch of frame-level alignments:
        >>> alignments = torch.tensor([
        ...     [0, 0, 0, 1, 57, 57, 35, 35, 35, ...],
        ...     [...],
        ...     ...
        ... ])

        :return: an int32 tensor with shape ``(batch_size, num_frames)``.
        """
        # Extract feats
        # (batch, seq_len, num_feats)
        if isinstance(cuts, (Cut, MixedCut)):
            cuts = CutSet.from_cuts([cuts])
        assert cuts[0].sampling_rate == self.sampling_rate, \
            f'{cuts[0].sampling_rate} != {self.sampling_rate}'

        cuts = cuts.map_supervisions(self.normalize_text)

        otf = OnTheFlyFeatures(self.extractor)
        feats, _ = otf(cuts)
        feats = feats.permute(0, 2, 1)
        texts = [' '.join(s.text for s in cut.supervisions) for cut in cuts]

        # Compute AM posteriors
        # (batch, approx. seq_len / 4, num_phones)
        posteriors, _, _ = self.model(feats)
        # Note: we are using "dummy" supervisions so that the aligner also considers
        # the padding area. We can adjust that behaviour if needed by passing actual
        # supervision segments, but then we will have a ragged tensor (will need to
        # pad the alignments themselves).
        sups = self.dummy_supervisions(feats)
        posteriors_fsa = k2.DenseFsaVec(posteriors.permute(0, 2, 1), sups)

        # Intersection with ground truth transcript graphs
        num, den = self.compiler.compile(texts, self.P)
        alignment = k2.intersect_dense(num, posteriors_fsa, output_beam=10.0)
        best_path = k2.shortest_path(alignment, use_double_scores=True)

        # Retrieve sequences of phone IDs per frame
        # (batch, approx. seq_len / 4) -- dtype int32 (phone labels)
        frame_labels = torch.stack(
            [best_path[i].labels[:-1] for i in range(best_path.shape[0])])
        return frame_labels
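As a follow-up, frame-level alignments like the ones documented above can be collapsed into (label, start_frame, num_frames) runs with plain torch; a small sketch using the values from the docstring example:

import torch

alignment = torch.tensor([0, 0, 0, 1, 57, 57, 35, 35, 35], dtype=torch.int32)

runs = []
start = 0
for t in range(1, len(alignment) + 1):
    if t == len(alignment) or alignment[t] != alignment[start]:
        runs.append((int(alignment[start]), start, t - start))
        start = t
print(runs)  # [(0, 0, 3), (1, 3, 1), (57, 4, 2), (35, 6, 3)]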
Example 16
def _compute_mmi_loss_pruned(
        nnet_output: torch.Tensor,
        texts: List[str],
        supervision_segments: torch.Tensor,
        graph_compiler: MmiTrainingGraphCompiler,
        P: k2.Fsa,
        den_scale: float = 1.0
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    '''
    See :func:`_compute_mmi_loss_exact_optimized` for the meaning
    of the arguments.

    `pruned` means it uses k2.intersect_dense_pruned

    Note:
      It uses the least amount of memory, but the loss is not exact due
      to pruning.
    '''
    num_graphs, den_graphs = graph_compiler.compile(texts,
                                                    P,
                                                    replicate_den=False)
    dense_fsa_vec = k2.DenseFsaVec(nnet_output, supervision_segments)

    num_lats = k2.intersect_dense(num_graphs, dense_fsa_vec, output_beam=10.0)

    # the values for search_beam/output_beam/min_active_states/max_active_states
    # are not tuned. You may want to tune them.
    den_lats = k2.intersect_dense_pruned(den_graphs,
                                         dense_fsa_vec,
                                         search_beam=20.0,
                                         output_beam=7.0,
                                         min_active_states=30,
                                         max_active_states=10000)

    num_tot_scores = num_lats.get_tot_scores(log_semiring=True,
                                             use_double_scores=True)

    den_tot_scores = den_lats.get_tot_scores(log_semiring=True,
                                             use_double_scores=True)

    tot_scores = num_tot_scores - den_scale * den_tot_scores
    tot_score, tot_frames, all_frames = get_tot_objf_and_num_frames(
        tot_scores, supervision_segments[:, 2])
    return tot_score, tot_frames, all_frames
Example 17
    def forward(
        self,
        log_probs: torch.Tensor,
        targets: torch.Tensor,
        input_lengths: torch.Tensor,
        target_lengths: torch.Tensor,
    ) -> torch.Tensor:
        if self.blank != 0:
            # rearrange log_probs to put blank at the first place
            # and shift targets to emulate blank = 0
            log_probs, targets = make_blank_first(self.blank, log_probs,
                                                  targets)
        supervisions, order = create_supervision(input_lengths)
        order = order.long()
        targets = targets[order]
        target_lengths = target_lengths[order]
        # PyTorch is doing the log-softmax normalization as part of the CTC computation.
        # More: https://github.com/k2-fsa/k2/issues/575
        log_probs = GradExpNormalize.apply(
            log_probs, input_lengths,
            "mean" if self.reduction != "sum" else "none")

        if log_probs.device != self.graph_compiler.device:
            self.graph_compiler.to(log_probs.device)
        num_graphs = self.graph_compiler.compile(
            targets + 1 if self.pad_fsavec else targets, target_lengths)

        dense_fsa_vec = (prep_padded_densefsavec(log_probs, supervisions)
                         if self.pad_fsavec else k2.DenseFsaVec(
                             log_probs, supervisions))

        num_lats = k2.intersect_dense(num_graphs, dense_fsa_vec,
                                      torch.finfo(torch.float32).max)

        # use_double_scores=True matters here: with single precision,
        # the total scores sometimes suffer from rounding errors.
        num_tot_scores = num_lats.get_tot_scores(log_semiring=True,
                                                 use_double_scores=True)
        tot_scores = num_tot_scores
        tot_scores, valid_mask = get_tot_objf_and_finite_mask(
            tot_scores, self.reduction)
        return -tot_scores[valid_mask], valid_mask
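The make_blank_first step above is not shown in this excerpt; a minimal sketch of the idea (an assumption for illustration, not the library's implementation) moves the blank's channel to position 0 and remaps targets to match:

import torch

def make_blank_first_sketch(blank: int, log_probs: torch.Tensor,
                            targets: torch.Tensor):
    # New channel order: [blank, 0, 1, ..., blank-1, blank+1, ...].
    num_classes = log_probs.shape[-1]
    index = torch.cat([
        torch.tensor([blank]),
        torch.arange(blank),
        torch.arange(blank + 1, num_classes),
    ]).to(log_probs.device)
    new_log_probs = log_probs.index_select(-1, index)
    # Old labels below `blank` shift up by one; labels above keep their index.
    new_targets = torch.where(targets < blank, targets + 1, targets)
    return new_log_probs, new_targets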
Example 18
    def test_two_dense(self):
        s = '''
            0 1 1 1.0
            1 1 1 50.0
            1 2 2 2.0
            2 3 -1 3.0
            3
        '''

        for use_map in [True, False]:

            fsa = k2.Fsa.from_str(s)
            fsa.requires_grad_(True)
            fsa_vec = k2.create_fsa_vec([fsa, fsa])
            log_prob = torch.tensor(
                [[[0.1, 0.2, 0.3], [0.04, 0.05, 0.06], [0.0, 0.0, 0.0]],
                 [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.0, 0.0, 0.0]]],
                dtype=torch.float32,
                requires_grad=True)

            if use_map:
                a_to_b_map = torch.tensor([0, 0], dtype=torch.int32)
            else:
                a_to_b_map = None

            supervision_segments = torch.tensor([[0, 0, 3], [1, 0, 2]],
                                                dtype=torch.int32)
            dense_fsa_vec = k2.DenseFsaVec(log_prob, supervision_segments)
            out_fsa = k2.intersect_dense(fsa_vec,
                                         dense_fsa_vec,
                                         output_beam=100000,
                                         a_to_b_map=a_to_b_map,
                                         seqframe_idx_name='seqframe',
                                         frame_idx_name='frame')

            if not use_map:
                assert torch.allclose(
                    out_fsa.seqframe,
                    torch.tensor([0, 1, 2, 3, 4, 5, 6], dtype=torch.int32))

                assert torch.allclose(
                    out_fsa.frame,
                    torch.tensor([0, 1, 2, 3, 0, 1, 2], dtype=torch.int32))
            else:
                assert torch.allclose(
                    out_fsa.seqframe,
                    torch.tensor([0, 1, 2, 3, 0, 1, 2, 3], dtype=torch.int32))

                assert torch.allclose(
                    out_fsa.frame,
                    torch.tensor([0, 1, 2, 3, 0, 1, 2, 3], dtype=torch.int32))

            assert out_fsa.shape == (2, None,
                                     None), 'There should be two FSAs!'

            scores = out_fsa.get_tot_scores(log_semiring=False,
                                            use_double_scores=False)
            scores.sum().backward()

            # `expected` results are computed using gtn.
            # See https://bit.ly/3oYObeb
            #  expected_scores_out_fsa = torch.tensor(
            #      [1.2, 2.06, 3.0, 1.2, 50.5, 2.0, 3.0])

            if not use_map:
                expected_grad_fsa = torch.tensor([2.0, 1.0, 2.0, 2.0])
            else:
                expected_grad_fsa = torch.tensor([2.0, 2.0, 2.0, 2.0])

            #  expected_grad_log_prob = torch.tensor([
            #      0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0, 0, 0, 0.0, 1.0, 0.0, 0.0, 1.0,
            #      0.0, 0.0, 0.0, 1.0
            #  ]).reshape_as(log_prob)

            assert torch.allclose(expected_grad_fsa, fsa.scores.grad)
Example 19
def _compute_mmi_loss_exact_optimized(
        nnet_output: torch.Tensor,
        texts: List[str],
        supervision_segments: torch.Tensor,
        graph_compiler: MmiTrainingGraphCompiler,
        P: k2.Fsa,
        den_scale: float = 1.0
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    '''
    `exact` in the function name means it uses a version of intersection
    without pruning.

    `optimized` means this function is optimized in that it calls
    k2.intersect_dense only once.
    Note:
      It is faster at the cost of using more memory.

    Args:
      nnet_output:
        A 3-D tensor of shape [N, T, C]
      texts:
        The transcripts. Each element consists of space-separated words.
      supervision_segments:
        A 2-D tensor that will be passed to :func:`k2.DenseFsaVec`.
      graph_compiler:
        Used to build num_graphs and den_graphs
      P:
        Represents a bigram Fsa.
      den_scale:
        The scale applied to the denominator tot_scores.
    '''
    num_graphs, den_graphs = graph_compiler.compile(texts,
                                                    P,
                                                    replicate_den=False)

    dense_fsa_vec = k2.DenseFsaVec(nnet_output, supervision_segments)

    device = num_graphs.device

    num_fsas = num_graphs.shape[0]
    assert dense_fsa_vec.dim0() == num_fsas

    assert den_graphs.shape[0] == 1

    # The aux_labels of num_graphs is a k2.RaggedInt,
    # but for den_graphs it is a torch.Tensor.
    #
    # The following converts den_graphs.aux_labels
    # from torch.Tensor to k2.RaggedInt so that
    # we can use k2.cat() later.
    den_graphs.convert_attr_to_ragged_(name='aux_labels')

    # The motivation to concatenate num_graphs and den_graphs
    # is to reduce the number of calls to k2.intersect_dense.
    num_den_graphs = k2.cat([num_graphs, den_graphs])

    # NOTE: The a_to_b_map in k2.intersect_dense must be sorted
    # so the following reorders num_den_graphs.
    #
    # The following code computes a_to_b_map

    # [0, 1, 2, ... ]
    num_graphs_indexes = torch.arange(num_fsas, dtype=torch.int32)

    # [num_fsas, num_fsas, num_fsas, ... ]
    den_graphs_indexes = torch.tensor([num_fsas] * num_fsas, dtype=torch.int32)

    # [0, num_fsas, 1, num_fsas, 2, num_fsas, ... ]
    num_den_graphs_indexes = torch.stack(
        [num_graphs_indexes, den_graphs_indexes]).t().reshape(-1).to(device)

    num_den_reordered_graphs = k2.index(num_den_graphs, num_den_graphs_indexes)

    # [[0, 1, 2, ...]]
    a_to_b_map = torch.arange(num_fsas, dtype=torch.int32).reshape(1, -1)

    # [[0, 1, 2, ...]] -> [0, 0, 1, 1, 2, 2, ... ]
    a_to_b_map = a_to_b_map.repeat(2, 1).t().reshape(-1).to(device)

    num_den_lats = k2.intersect_dense(num_den_reordered_graphs,
                                      dense_fsa_vec,
                                      output_beam=10.0,
                                      a_to_b_map=a_to_b_map)

    num_den_tot_scores = num_den_lats.get_tot_scores(log_semiring=True,
                                                     use_double_scores=True)

    num_tot_scores = num_den_tot_scores[::2]
    den_tot_scores = num_den_tot_scores[1::2]

    tot_scores = num_tot_scores - den_scale * den_tot_scores
    tot_score, tot_frames, all_frames = get_tot_objf_and_num_frames(
        tot_scores, supervision_segments[:, 2])
    return tot_score, tot_frames, all_frames
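Because the lattices were interleaved as [num_0, den_0, num_1, den_1, ...], the even/odd slicing above separates the two score streams; a trivial illustration with assumed values:

import torch

scores = torch.tensor([10., 1., 20., 2., 30., 3.])
print(scores[::2].tolist())   # [10.0, 20.0, 30.0]  -> numerator scores
print(scores[1::2].tolist())  # [1.0, 2.0, 3.0]     -> denominator scores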
Example 20
def get_objf(batch: Dict,
             model: AcousticModel,
             device: torch.device,
             graph_compiler: CtcTrainingGraphCompiler,
             training: bool,
             optimizer: Optional[torch.optim.Optimizer] = None):
    feature = batch['features']
    supervisions = batch['supervisions']
    supervision_segments = torch.stack(
        (supervisions['sequence_idx'],
         torch.floor_divide(supervisions['start_frame'],
                            model.subsampling_factor),
         torch.floor_divide(supervisions['num_frames'],
                            model.subsampling_factor)), 1).to(torch.int32)
    indices = torch.argsort(supervision_segments[:, 2], descending=True)
    supervision_segments = supervision_segments[indices]

    texts = supervisions['text']
    texts = [texts[idx] for idx in indices]
    assert feature.ndim == 3
    # print(supervision_segments[:, 1] + supervision_segments[:, 2])

    feature = feature.to(device)
    # at entry, feature is [N, T, C]
    feature = feature.permute(0, 2, 1)  # now feature is [N, C, T]
    if training:
        nnet_output = model(feature)
    else:
        with torch.no_grad():
            nnet_output = model(feature)

    # nnet_output is [N, C, T]
    nnet_output = nnet_output.permute(0, 2, 1)  # now nnet_output is [N, T, C]

    decoding_graph = graph_compiler.compile(texts).to(device)

    # nnet_output2 = nnet_output.clone()
    # blank_bias = -7.0
    # nnet_output2[:,:,0] += blank_bias

    dense_fsa_vec = k2.DenseFsaVec(nnet_output, supervision_segments)
    assert decoding_graph.is_cuda()
    assert decoding_graph.device == device
    assert nnet_output.device == device
    # TODO(haowen): with a small `beam`, we may get empty `target_graph`,
    # thus `tot_scores` will be `inf`. Definitely we need to handle this later.
    target_graph = k2.intersect_dense(decoding_graph, dense_fsa_vec, 10.0)

    tot_scores = k2.get_tot_scores(target_graph,
                                   log_semiring=True,
                                   use_double_scores=True)

    (tot_score, tot_frames,
     all_frames) = get_tot_objf_and_num_frames(tot_scores,
                                               supervision_segments[:, 2])

    if training:
        optimizer.zero_grad()
        (-tot_score).backward()
        clip_grad_value_(model.parameters(), 5.0)
        optimizer.step()

    ans = (-tot_score.detach().cpu().item(), tot_frames.cpu().item(),
           all_frames.cpu().item())
    return ans
Example 21
def get_objf(batch: Dict,
             model: AcousticModel,
             device: torch.device,
             graph_compiler: CtcTrainingGraphCompiler,
             is_training: bool,
             is_update: bool,
             accum_grad: int = 1,
             att_rate: float = 0.0,
             optimizer: Optional[torch.optim.Optimizer] = None):
    feature = batch['features']
    supervisions = batch['supervisions']
    supervision_segments = torch.stack(
        (supervisions['sequence_idx'],
         (((supervisions['start_frame'] - 1) // 2 - 1) // 2),
         (((supervisions['num_frames'] - 1) // 2 - 1) // 2)), 1).to(torch.int32)
    supervision_segments = torch.clamp(supervision_segments, min=0)
    indices = torch.argsort(supervision_segments[:, 2], descending=True)
    supervision_segments = supervision_segments[indices]

    texts = supervisions['text']
    texts = [texts[idx] for idx in indices]
    assert feature.ndim == 3
    # print(supervision_segments[:, 1] + supervision_segments[:, 2])

    feature = feature.to(device)
    # at entry, feature is [N, T, C]
    feature = feature.permute(0, 2, 1)  # now feature is [N, C, T]
    if is_training:
        nnet_output, encoder_memory, memory_mask = model(feature, supervision_segments)
        if att_rate != 0.0:
            att_loss = model.decoder_forward(encoder_memory, memory_mask, supervisions, graph_compiler)
    else:
        with torch.no_grad():
            nnet_output, encoder_memory, memory_mask = model(feature, supervision_segments)
            if att_rate != 0.0:
                att_loss = model.decoder_forward(encoder_memory, memory_mask, supervisions, graph_compiler)

    # nnet_output is [N, C, T]
    nnet_output = nnet_output.permute(0, 2, 1)  # now nnet_output is [N, T, C]

    decoding_graph = graph_compiler.compile(texts).to(device)

    # nnet_output2 = nnet_output.clone()
    # blank_bias = -7.0
    # nnet_output2[:,:,0] += blank_bias

    dense_fsa_vec = k2.DenseFsaVec(nnet_output, supervision_segments)
    assert decoding_graph.is_cuda()
    assert decoding_graph.device == device
    assert nnet_output.device == device

    target_graph = k2.intersect_dense(decoding_graph, dense_fsa_vec, 10.0)

    tot_scores = target_graph.get_tot_scores(
        log_semiring=True,
        use_double_scores=True)

    (tot_score, tot_frames,
     all_frames) = get_tot_objf_and_num_frames(tot_scores,
                                               supervision_segments[:, 2])

    if is_training:
        if att_rate != 0.0:
            loss = (- (1.0 - att_rate) * tot_score + att_rate * att_loss) / (len(texts) * accum_grad)
        else:
            loss = (-tot_score) / (len(texts) * accum_grad)
        loss.backward()
        if is_update:
            clip_grad_value_(model.parameters(), 5.0)
            optimizer.step()
            optimizer.zero_grad()

    ans = (-tot_score.detach().cpu().item(), tot_frames.cpu().item(),
           all_frames.cpu().item())
    return ans
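The expression (((n - 1) // 2 - 1) // 2) used above matches two consecutive stride-2 subsampling layers, each mapping n frames to (n - 1) // 2; a quick sanity check with assumed example lengths:

for n in [10, 100, 163]:
    after_first = (n - 1) // 2
    after_second = (after_first - 1) // 2
    assert after_second == (((n - 1) // 2 - 1) // 2)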
Example 22
    def test_random_case2(self):
        # 2 sequences
        for device in self.devices:
            T1 = torch.randint(10, 200, (1, )).item()
            T2 = torch.randint(9, 100, (1, )).item()
            C = torch.randint(20, 30, (1, )).item()
            if T1 < T2:
                T1, T2 = T2, T1

            torch_activation_1 = torch.rand((T1, C),
                                            dtype=torch.float32,
                                            device=device).requires_grad_(True)
            torch_activation_2 = torch.rand((T2, C),
                                            dtype=torch.float32,
                                            device=device).requires_grad_(True)

            k2_activation_1 = torch_activation_1.detach().clone(
            ).requires_grad_(True)
            k2_activation_2 = torch_activation_2.detach().clone(
            ).requires_grad_(True)

            # [T, N, C]
            torch_activations = torch.nn.utils.rnn.pad_sequence(
                [torch_activation_1, torch_activation_2],
                batch_first=False,
                padding_value=0)

            # [N, T, C]
            k2_activations = torch.nn.utils.rnn.pad_sequence(
                [k2_activation_1, k2_activation_2],
                batch_first=True,
                padding_value=0)

            target_length1 = torch.randint(1, T1, (1, )).item()
            target_length2 = torch.randint(1, T2, (1, )).item()

            target_lengths = torch.tensor([target_length1,
                                           target_length2]).to(device)
            targets = torch.randint(1, C - 1,
                                    (target_lengths.sum(), )).to(device)

            # [T, N, C]
            torch_log_probs = torch.nn.functional.log_softmax(
                torch_activations, dim=-1)
            input_lengths = torch.tensor([T1, T2]).to(device)

            torch_loss = torch.nn.functional.ctc_loss(
                log_probs=torch_log_probs,
                targets=targets,
                input_lengths=input_lengths,
                target_lengths=target_lengths,
                reduction='none')

            assert T1 >= T2
            supervision_segments = torch.tensor([[0, 0, T1], [1, 0, T2]],
                                                dtype=torch.int32)
            k2_log_probs = torch.nn.functional.log_softmax(k2_activations,
                                                           dim=-1)
            dense_fsa_vec = k2.DenseFsaVec(k2_log_probs,
                                           supervision_segments).to(device)
            ctc_topo_inv = k2.arc_sort(
                build_ctc_topo(list(range(C))).invert_())
            linear_fsa = k2.linear_fsa([
                targets[:target_length1].tolist(),
                targets[target_length1:].tolist()
            ])
            decoding_graph = k2.intersect(ctc_topo_inv, linear_fsa)
            decoding_graph = k2.connect(decoding_graph).invert_().to(device)
            target_graph = k2.intersect_dense(decoding_graph, dense_fsa_vec,
                                              100.0)
            k2_scores = target_graph.get_tot_scores(log_semiring=True,
                                                    use_double_scores=False)
            assert torch.allclose(torch_loss, -1 * k2_scores)
            scale = torch.rand_like(torch_loss) * 100
            (torch_loss * scale).sum().backward()
            (-k2_scores * scale).sum().backward()
            assert torch.allclose(torch_activation_1.grad,
                                  k2_activation_1.grad,
                                  atol=1e-2)
            assert torch.allclose(torch_activation_2.grad,
                                  k2_activation_2.grad,
                                  atol=1e-2)
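The two pad_sequence calls above differ only in layout: torch.nn.functional.ctc_loss wants (T, N, C) while k2.DenseFsaVec wants (N, T, C). A standalone shape check with assumed sizes:

import torch
from torch.nn.utils.rnn import pad_sequence

a = torch.zeros(5, 3)  # (T1, C)
b = torch.zeros(2, 3)  # (T2, C)
print(pad_sequence([a, b], batch_first=False).shape)  # (5, 2, 3) = (T, N, C)
print(pad_sequence([a, b], batch_first=True).shape)   # (2, 5, 3) = (N, T, C)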
Example 23
def get_loss(batch: Dict,
             model: AcousticModel,
             P: k2.Fsa,
             device: torch.device,
             graph_compiler: MmiMbrTrainingGraphCompiler,
             is_training: bool,
             optimizer: Optional[torch.optim.Optimizer] = None):
    assert P.device == device
    feature = batch['features']
    supervisions = batch['supervisions']
    supervision_segments = torch.stack(
        (supervisions['sequence_idx'],
         torch.floor_divide(supervisions['start_frame'],
                            model.subsampling_factor),
         torch.floor_divide(supervisions['num_frames'],
                            model.subsampling_factor)), 1).to(torch.int32)
    indices = torch.argsort(supervision_segments[:, 2], descending=True)
    supervision_segments = supervision_segments[indices]

    texts = supervisions['text']
    texts = [texts[idx] for idx in indices]
    assert feature.ndim == 3
    # print(supervision_segments[:, 1] + supervision_segments[:, 2])

    feature = feature.to(device)
    # at entry, feature is [N, T, C]
    feature = feature.permute(0, 2, 1)  # now feature is [N, C, T]
    if is_training:
        nnet_output = model(feature)
    else:
        with torch.no_grad():
            nnet_output = model(feature)

    # nnet_output is [N, C, T]
    nnet_output = nnet_output.permute(0, 2, 1)  # now nnet_output is [N, T, C]

    if is_training:
        num_graph, den_graph, decoding_graph = graph_compiler.compile(texts, P)
    else:
        with torch.no_grad():
            num_graph, den_graph, decoding_graph = graph_compiler.compile(
                texts, P)

    assert num_graph.requires_grad == is_training
    assert den_graph.requires_grad is False
    assert decoding_graph.requires_grad is False
    assert len(
        decoding_graph.shape) == 2 or decoding_graph.shape == (1, None, None)

    num_graph = num_graph.to(device)
    den_graph = den_graph.to(device)

    decoding_graph = decoding_graph.to(device)

    dense_fsa_vec = k2.DenseFsaVec(nnet_output, supervision_segments)
    assert nnet_output.device == device

    num_lats = k2.intersect_dense(num_graph,
                                  dense_fsa_vec,
                                  10.0,
                                  seqframe_idx_name='seqframe_idx')

    mbr_lats = k2.intersect_dense_pruned(decoding_graph,
                                         dense_fsa_vec,
                                         20.0,
                                         7.0,
                                         30,
                                         10000,
                                         seqframe_idx_name='seqframe_idx')

    if True:
        # WARNING: the else branch is not working at present (the total loss is not stable)
        den_lats = k2.intersect_dense(den_graph, dense_fsa_vec, 10.0)
    else:
        # in this case, we can remove den_graph
        den_lats = mbr_lats

    num_tot_scores = num_lats.get_tot_scores(log_semiring=True,
                                             use_double_scores=True)

    den_tot_scores = den_lats.get_tot_scores(log_semiring=True,
                                             use_double_scores=True)

    if den_lats is mbr_lats:
        # Some entries in den_tot_scores may be -inf.
        # The corresponding sequences are discarded/ignored.
        finite_indexes = torch.isfinite(den_tot_scores)
        den_tot_scores = den_tot_scores[finite_indexes]
        num_tot_scores = num_tot_scores[finite_indexes]
    else:
        finite_indexes = None

    # NOTE: den_scale is assumed to be defined in the enclosing scope
    # (e.g., a module-level constant); it is not a parameter of get_loss.
    tot_scores = num_tot_scores - den_scale * den_tot_scores

    (tot_score, tot_frames,
     all_frames) = get_tot_objf_and_num_frames(tot_scores,
                                               supervision_segments[:, 2],
                                               finite_indexes)

    num_rows = dense_fsa_vec.scores.shape[0]
    num_cols = dense_fsa_vec.scores.shape[1] - 1
    mbr_num_sparse = k2.create_sparse(rows=num_lats.seqframe_idx,
                                      cols=num_lats.phones,
                                      values=num_lats.get_arc_post(True,
                                                                   True).exp(),
                                      size=(num_rows, num_cols),
                                      min_col_index=0)

    mbr_den_sparse = k2.create_sparse(rows=mbr_lats.seqframe_idx,
                                      cols=mbr_lats.phones,
                                      values=mbr_lats.get_arc_post(True,
                                                                   True).exp(),
                                      size=(num_rows, num_cols),
                                      min_col_index=0)
    # NOTE: Due to limited support of PyTorch's autograd for sparse tensors,
    # we cannot use (mbr_num_sparse - mbr_den_sparse) here
    #
    # The following works only for torch >= 1.7.0
    mbr_loss = torch.sparse.sum(
        k2.sparse.abs((mbr_num_sparse + (-mbr_den_sparse)).coalesce()))

    mmi_loss = -tot_score

    total_loss = mmi_loss + mbr_loss

    if is_training:
        optimizer.zero_grad()
        total_loss.backward()
        clip_grad_value_(model.parameters(), 5.0)
        optimizer.step()

    ans = (
        mmi_loss.detach().cpu().item(),
        mbr_loss.detach().cpu().item(),
        tot_frames.cpu().item(),
        all_frames.cpu().item(),
    )
    return ans
Example 24
    def forward(
            self, nnet_output: torch.Tensor, texts: List,
            supervision_segments: torch.Tensor
    ) -> Tuple[torch.Tensor, int, int]:
        num_graphs, den_graphs = self.graph_compiler.compile(
            texts, self.P, replicate_den=False)

        dense_fsa_vec = k2.DenseFsaVec(nnet_output, supervision_segments)

        device = num_graphs.device

        num_fsas = num_graphs.shape[0]
        assert dense_fsa_vec.dim0() == num_fsas

        assert den_graphs.shape[0] == 1

        # The aux_labels of num_graphs is a k2.RaggedInt,
        # but for den_graphs it is a torch.Tensor.
        #
        # The following converts den_graphs.aux_labels
        # from torch.Tensor to k2.RaggedInt so that
        # we can use k2.cat() later.
        den_graphs.convert_attr_to_ragged_(name='aux_labels')

        num_den_graphs = k2.cat([num_graphs, den_graphs])

        # NOTE: The a_to_b_map in k2.intersect_dense must be sorted
        # so the following reorders num_den_graphs.

        # [0, 1, 2, ... ]
        num_graphs_indexes = torch.arange(num_fsas, dtype=torch.int32)

        # [num_fsas, num_fsas, num_fsas, ... ]
        den_graphs_indexes = torch.tensor([num_fsas] * num_fsas,
                                          dtype=torch.int32)

        # [0, num_fsas, 1, num_fsas, 2, num_fsas, ... ]
        num_den_graphs_indexes = torch.stack(
            [num_graphs_indexes,
             den_graphs_indexes]).t().reshape(-1).to(device)

        num_den_reordered_graphs = k2.index(num_den_graphs,
                                            num_den_graphs_indexes)

        # [[0, 1, 2, ...]]
        a_to_b_map = torch.arange(num_fsas, dtype=torch.int32).reshape(1, -1)

        # [[0, 1, 2, ...]] -> [0, 0, 1, 1, 2, 2, ... ]
        a_to_b_map = a_to_b_map.repeat(2, 1).t().reshape(-1).to(device)

        num_den_lats = k2.intersect_dense(num_den_reordered_graphs,
                                          dense_fsa_vec,
                                          output_beam=10.0,
                                          a_to_b_map=a_to_b_map)

        num_den_tot_scores = num_den_lats.get_tot_scores(
            log_semiring=True, use_double_scores=True)

        num_tot_scores = num_den_tot_scores[::2]
        den_tot_scores = num_den_tot_scores[1::2]

        tot_scores = num_tot_scores - self.den_scale * den_tot_scores
        tot_score, tot_frames, all_frames = get_tot_objf_and_num_frames(
            tot_scores, supervision_segments[:, 2])
        return tot_score, tot_frames, all_frames
Example no. 25
    def test_case4(self):
        for device in self.devices:
            # put case3, case2 and case1 into a batch
            torch_activation_1 = torch.tensor(
                [[0., 0., 0., 0., 0.]]).to(device).requires_grad_(True)

            torch_activation_2 = torch.arange(1, 16).reshape(3, 5).to(
                torch.float32).to(device).requires_grad_(True)

            torch_activation_3 = torch.tensor([
                [-5, -4, -3, -2, -1],
                [-10, -9, -8, -7, -6],
                [-15, -14, -13, -12, -11.],
            ]).to(device).requires_grad_(True)

            k2_activation_1 = (
                torch_activation_1.detach().clone().requires_grad_(True))
            k2_activation_2 = (
                torch_activation_2.detach().clone().requires_grad_(True))
            k2_activation_3 = (
                torch_activation_3.detach().clone().requires_grad_(True))

            # [T, N, C]
            torch_activations = torch.nn.utils.rnn.pad_sequence(
                [torch_activation_3, torch_activation_2, torch_activation_1],
                batch_first=False,
                padding_value=0)

            # [N, T, C]
            k2_activations = torch.nn.utils.rnn.pad_sequence(
                [k2_activation_3, k2_activation_2, k2_activation_1],
                batch_first=True,
                padding_value=0)

            # [[b,c], [c,c], [a]]
            targets = torch.tensor([2, 3, 3, 3, 1]).to(device)
            input_lengths = torch.tensor([3, 3, 1]).to(device)
            target_lengths = torch.tensor([2, 2, 1]).to(device)

            torch_log_probs = torch.nn.functional.log_softmax(
                torch_activations, dim=-1)  # (T, N, C)

            torch_loss = torch.nn.functional.ctc_loss(
                log_probs=torch_log_probs,
                targets=targets,
                input_lengths=input_lengths,
                target_lengths=target_lengths,
                reduction='none')

            assert torch.allclose(
                torch_loss,
                torch.tensor([4.938850402832, 7.355742931366,
                              1.6094379425049]).to(device))

            k2_log_probs = torch.nn.functional.log_softmax(k2_activations,
                                                           dim=-1)
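            # each row of supervision_segments is
            # (sequence_idx, start_frame, num_frames)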
            supervision_segments = torch.tensor(
                [[0, 0, 3], [1, 0, 3], [2, 0, 1]], dtype=torch.int32)
            dense_fsa_vec = k2.DenseFsaVec(k2_log_probs,
                                           supervision_segments).to(device)

            ctc_topo_inv = k2.arc_sort(
                build_ctc_topo([0, 1, 2, 3, 4]).invert_())
            # [ [b, c], [c, c], [a]]
            linear_fsa = k2.linear_fsa([[2, 3], [3, 3], [1]])
            decoding_graph = k2.intersect(ctc_topo_inv, linear_fsa)
            decoding_graph = k2.connect(decoding_graph).invert_().to(device)

            target_graph = k2.intersect_dense(decoding_graph, dense_fsa_vec,
                                              100.0)

            k2_scores = target_graph.get_tot_scores(log_semiring=True,
                                                    use_double_scores=False)
            assert torch.allclose(torch_loss, -1 * k2_scores)

            scale = torch.tensor([1., -2, 3.5]).to(device)
            (torch_loss * scale).sum().backward()
            (-k2_scores * scale).sum().backward()
            assert torch.allclose(torch_activation_1.grad,
                                  k2_activation_1.grad)
            assert torch.allclose(torch_activation_2.grad,
                                  k2_activation_2.grad)
            assert torch.allclose(torch_activation_3.grad,
                                  k2_activation_3.grad)
Example no. 26
    def test_two_dense(self):
        s = '''
            0 1 1 1.0
            1 1 1 50.0
            1 2 2 2.0
            2 3 -1 3.0
            3
        '''

        for device in self.devices:
            for use_map in [True, False]:
                fsa = k2.Fsa.from_str(s).to(device)
                fsa.requires_grad_(True)
                fsa_vec = k2.create_fsa_vec([fsa, fsa])
                log_prob = torch.tensor(
                    [[[0.1, 0.2, 0.3], [0.04, 0.05, 0.06], [0.0, 0.0, 0.0]],
                     [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.0, 0.0, 0.0]]],
                    dtype=torch.float32,
                    device=device,
                    requires_grad=True)

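                # a_to_b_map[i] selects the dense sequence that fsa_vec[i]
                # is intersected with; [0, 0] makes both copies of the FSA
                # see the first (3-frame) supervision.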
                if use_map:
                    a_to_b_map = torch.tensor([0, 0],
                                              dtype=torch.int32,
                                              device=device)
                else:
                    a_to_b_map = None

                supervision_segments = torch.tensor([[0, 0, 3], [1, 0, 2]],
                                                    dtype=torch.int32)
                dense_fsa_vec = k2.DenseFsaVec(log_prob, supervision_segments)
                out_fsa = k2.intersect_dense(fsa_vec,
                                             dense_fsa_vec,
                                             output_beam=100000,
                                             a_to_b_map=a_to_b_map,
                                             seqframe_idx_name='seqframe',
                                             frame_idx_name='frame')

                if not use_map:
                    assert torch.all(
                        torch.eq(
                            out_fsa.seqframe,
                            torch.tensor([0, 1, 2, 3, 4, 5, 6],
                                         device=device)))

                    assert torch.all(
                        torch.eq(
                            out_fsa.frame,
                            torch.tensor([0, 1, 2, 3, 0, 1, 2],
                                         device=device)))
                else:
                    assert torch.all(
                        torch.eq(
                            out_fsa.seqframe,
                            torch.tensor([0, 1, 2, 3, 0, 1, 2, 3],
                                         device=device)))

                    assert torch.all(
                        torch.eq(
                            out_fsa.frame,
                            torch.tensor([0, 1, 2, 3, 0, 1, 2, 3],
                                         device=device)))

                assert out_fsa.shape == (2, None,
                                         None), 'There should be two FSAs!'

                scores = out_fsa.get_tot_scores(log_semiring=False,
                                                use_double_scores=False)
                scores.sum().backward()

                # `expected` results are computed using gtn.
                # See https://colab.research.google.com/drive/1FzEFjj5GoCDN2d05D9jE682CkR7QIlnm?usp=sharing
                if not use_map:
                    expected_scores_out_fsa = torch.tensor(
                        [1.2, 50.05, 2.0, 3.0, 1.2, 2.6, 3.0], device=device)
                else:
                    expected_scores_out_fsa = torch.tensor(
                        [1.2, 50.05, 2.0, 3.0, 1.2, 50.05, 2.0, 3.0],
                        device=device)
                assert torch.allclose(out_fsa.scores, expected_scores_out_fsa)

                if not use_map:
                    expected_grad_fsa = torch.tensor([2.0, 1.0, 2.0, 2.0],
                                                     device=device)
                else:
                    expected_grad_fsa = torch.tensor([2.0, 2.0, 2.0, 2.0],
                                                     device=device)

                assert torch.allclose(expected_grad_fsa, fsa.scores.grad)

                if not use_map:
                    expected_grad_log_prob = torch.tensor(
                        [
                            0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0,
                            1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0
                        ],
                        device=device).reshape_as(log_prob)
                else:
                    expected_grad_log_prob = torch.tensor(
                        [
                            0.0, 2.0, 0.0, 0.0, 2.0, 0.0, 0.0, 0.0, 2.0, 0.0,
                            0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0
                        ],
                        device=device).reshape_as(log_prob)
                assert torch.allclose(expected_grad_log_prob, log_prob.grad)
Example no. 27
def get_objf(batch: Dict,
             model: AcousticModel,
             P: k2.Fsa,
             device: torch.device,
             graph_compiler: MmiTrainingGraphCompiler,
             is_training: bool,
             optimizer: Optional[torch.optim.Optimizer] = None):
    feature = batch['features']
    supervisions = batch['supervisions']
    supervision_segments = torch.stack(
        (supervisions['sequence_idx'],
         torch.floor_divide(supervisions['start_frame'],
                            model.subsampling_factor),
         torch.floor_divide(supervisions['num_frames'],
                            model.subsampling_factor)), 1).to(torch.int32)
    indices = torch.argsort(supervision_segments[:, 2], descending=True)
    supervision_segments = supervision_segments[indices]

    texts = supervisions['text']
    texts = [texts[idx] for idx in indices]
    assert feature.ndim == 3
    # print(supervision_segments[:, 1] + supervision_segments[:, 2])

    feature = feature.to(device)
    # at entry, feature is [N, T, C]
    feature = feature.permute(0, 2, 1)  # now feature is [N, C, T]
    if is_training:
        nnet_output = model(feature)
    else:
        with torch.no_grad():
            nnet_output = model(feature)

    # nnet_output is [N, C, T]
    nnet_output = nnet_output.permute(0, 2, 1)  # now nnet_output is [N, T, C]

    if is_training:
        num, den = graph_compiler.compile(texts, P)
    else:
        with torch.no_grad():
            num, den = graph_compiler.compile(texts, P)

    assert num.requires_grad == is_training
    assert den.requires_grad is False
    num = num.to(device)
    den = den.to(device)

    # nnet_output2 = nnet_output.clone()
    # blank_bias = -7.0
    # nnet_output2[:,:,0] += blank_bias

    dense_fsa_vec = k2.DenseFsaVec(nnet_output, supervision_segments)
    assert nnet_output.device == device

    num = k2.intersect_dense(num, dense_fsa_vec, 10.0)
    den = k2.intersect_dense(den, dense_fsa_vec, 10.0)

    num_tot_scores = num.get_tot_scores(log_semiring=True,
                                        use_double_scores=True)
    den_tot_scores = den.get_tot_scores(log_semiring=True,
                                        use_double_scores=True)
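    # den_scale is assumed to be defined elsewhere in the original
    # training script (e.g. as a module-level constant); it is not a
    # local of this function.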
    tot_scores = num_tot_scores - den_scale * den_tot_scores

    (tot_score, tot_frames,
     all_frames) = get_tot_objf_and_num_frames(tot_scores,
                                               supervision_segments[:, 2])

    if is_training:
        optimizer.zero_grad()
        (-tot_score).backward()
        clip_grad_value_(model.parameters(), 5.0)
        optimizer.step()

    ans = (
        -tot_score.detach().cpu().item(),
        tot_frames.cpu().item(),
        all_frames.cpu().item(),
    )
    return ans
Example no. 28
    def test_two_fsas(self):
        s1 = '''
            0 1 1 1.0
            1 1 1 50.0
            1 2 2 2.0
            2 3 -1 3.0
            3
        '''

        s2 = '''
            0 1 1 1.0
            1 2 2 2.0
            2 3 -1 3.0
            3
        '''
        for device in self.devices:
            fsa1 = k2.Fsa.from_str(s1).to(device)
            fsa2 = k2.Fsa.from_str(s2).to(device)

            fsa1.requires_grad_(True)
            fsa2.requires_grad_(True)

            fsa_vec = k2.create_fsa_vec([fsa1, fsa2])

            log_prob = torch.tensor(
                [[[0.1, 0.2, 0.3], [0.04, 0.05, 0.06], [0.0, 0.0, 0.0]],
                 [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.0, 0.0, 0.0]]],
                dtype=torch.float32,
                device=device,
                requires_grad=True)

            supervision_segments = torch.tensor([[0, 0, 3], [1, 0, 2]],
                                                dtype=torch.int32)
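            # note: num_frames is in decreasing order (3, then 2); the
            # training examples above argsort their segments for the
            # same reason.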
            dense_fsa_vec = k2.DenseFsaVec(log_prob, supervision_segments)
            out_fsa = k2.intersect_dense(fsa_vec,
                                         dense_fsa_vec,
                                         output_beam=100000,
                                         seqframe_idx_name='seqframe',
                                         frame_idx_name='frame')
            assert torch.all(
                torch.eq(out_fsa.seqframe,
                         torch.tensor([0, 1, 2, 3, 4, 5, 6], device=device)))

            assert torch.all(
                torch.eq(out_fsa.frame,
                         torch.tensor([0, 1, 2, 3, 0, 1, 2], device=device)))

            assert out_fsa.shape == (2, None,
                                     None), 'There should be two FSAs!'

            scores = out_fsa.get_tot_scores(log_semiring=False,
                                            use_double_scores=False)
            scores.sum().backward()

            # `expected` results are computed using gtn.
            # See https://bit.ly/3oYObeb
            #  expected_scores_out_fsa = torch.tensor(
            #      [1.2, 2.06, 3.0, 1.2, 50.5, 2.0, 3.0])

            expected_grad_fsa1 = torch.tensor([1.0, 1.0, 1.0, 1.0],
                                              device=device)
            expected_grad_fsa2 = torch.tensor([1.0, 1.0, 1.0], device=device)
            # TODO(dan): fix this.
            #  expected_grad_log_prob = torch.tensor([
            #      0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0, 0, 0, 0.0, 1.0, 0.0, 0.0,
            #      1.0, 0.0, 0.0, 0.0, 1.0
            #  ]).reshape_as(log_prob)

            # assert torch.allclose(out_fsa.scores, expected_scores_out_fsa)
            assert torch.allclose(expected_grad_fsa1, fsa1.scores.grad)
            assert torch.allclose(expected_grad_fsa2, fsa2.scores.grad)
Example no. 29
def get_objf(
    batch: Dict,
    model: AcousticModel,
    device: torch.device,
    graph_compiler: CtcTrainingGraphCompiler,
    training: bool,
    optimizer: Optional[torch.optim.Optimizer] = None,
):
    feature = batch["inputs"]
    supervisions = batch["supervisions"]
    supervision_segments = torch.stack(
        (
            supervisions["sequence_idx"],
            torch.floor_divide(supervisions["start_frame"],
                               model.subsampling_factor),
            torch.floor_divide(supervisions["num_frames"],
                               model.subsampling_factor),
        ),
        1,
    ).to(torch.int32)
    indices = torch.argsort(supervision_segments[:, 2], descending=True)
    supervision_segments = supervision_segments[indices]

    texts = supervisions["text"]
    texts = [texts[idx] for idx in indices]
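    # reorder the transcripts with the same permutation so that texts[i]
    # still corresponds to supervision_segments[i]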
    assert feature.ndim == 3

    feature = feature.to(device)
    # at entry, feature is [N, T, C]
    feature = feature.permute(0, 2, 1)  # now feature is [N, C, T]
    if training:
        nnet_output = model(feature)
    else:
        with torch.no_grad():
            nnet_output = model(feature)

    # nnet_output is [N, C, T]
    nnet_output = nnet_output.permute(0, 2, 1)  # now nnet_output is [N, T, C]

    decoding_graph = graph_compiler.compile(texts).to(device)

    dense_fsa_vec = k2.DenseFsaVec(nnet_output, supervision_segments)
    assert decoding_graph.is_cuda()
    assert decoding_graph.device == device
    assert nnet_output.device == device

    target_graph = k2.intersect_dense(decoding_graph, dense_fsa_vec, 10.0)

    tot_scores = target_graph.get_tot_scores(log_semiring=True,
                                             use_double_scores=True)

    (tot_score, tot_frames,
     all_frames) = get_tot_objf_and_num_frames(tot_scores,
                                               supervision_segments[:, 2])

    if training:
        optimizer.zero_grad()
        (-tot_score).backward()
        clip_grad_value_(model.parameters(), 5.0)
        optimizer.step()

    ans = (
        -tot_score.detach().cpu().item(),
        tot_frames.cpu().item(),
        all_frames.cpu().item(),
    )
    return ans
Example no. 30
def get_objf(batch: Dict,
             model: AcousticModel,
             P: k2.Fsa,
             device: torch.device,
             graph_compiler: MmiTrainingGraphCompiler,
             is_training: bool,
             tb_writer: Optional[SummaryWriter] = None,
             global_batch_idx_train: Optional[int] = None,
             optimizer: Optional[torch.optim.Optimizer] = None):
    feature = batch['features']
    supervisions = batch['supervisions']
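    # under DistributedDataParallel the wrapped model lives in
    # model.module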
    subsampling_factor = model.module.subsampling_factor if isinstance(
        model, DDP) else model.subsampling_factor
    supervision_segments = torch.stack(
        (supervisions['sequence_idx'],
         torch.floor_divide(supervisions['start_frame'], subsampling_factor),
         torch.floor_divide(supervisions['num_frames'], subsampling_factor)),
        1).to(torch.int32)
    indices = torch.argsort(supervision_segments[:, 2], descending=True)
    supervision_segments = supervision_segments[indices]

    texts = supervisions['text']
    texts = [texts[idx] for idx in indices]
    assert feature.ndim == 3
    # print(supervision_segments[:, 1] + supervision_segments[:, 2])

    feature = feature.to(device)
    # at entry, feature is [N, T, C]
    feature = feature.permute(0, 2, 1)  # now feature is [N, C, T]
    if is_training:
        nnet_output = model(feature)
    else:
        with torch.no_grad():
            nnet_output = model(feature)

    # nnet_output is [N, C, T]
    nnet_output = nnet_output.permute(0, 2, 1)  # now nnet_output is [N, T, C]

    if is_training:
        num, den = graph_compiler.compile(texts, P)
    else:
        with torch.no_grad():
            num, den = graph_compiler.compile(texts, P)

    assert num.requires_grad == is_training
    assert den.requires_grad is False
    num = num.to(device)
    den = den.to(device)

    # nnet_output2 = nnet_output.clone()
    # blank_bias = -7.0
    # nnet_output2[:,:,0] += blank_bias

    dense_fsa_vec = k2.DenseFsaVec(nnet_output, supervision_segments)
    assert nnet_output.device == device

    num = k2.intersect_dense(num, dense_fsa_vec, 10.0)
    den = k2.intersect_dense(den, dense_fsa_vec, 10.0)

    num_tot_scores = num.get_tot_scores(log_semiring=True,
                                        use_double_scores=True)
    den_tot_scores = den.get_tot_scores(log_semiring=True,
                                        use_double_scores=True)
    tot_scores = num_tot_scores - den_scale * den_tot_scores

    (tot_score, tot_frames,
     all_frames) = get_tot_objf_and_num_frames(tot_scores,
                                               supervision_segments[:, 2])

    if is_training:

        def maybe_log_gradients(tag: str):
            if (tb_writer is not None and global_batch_idx_train is not None
                    and global_batch_idx_train % 200 == 0):
                tb_writer.add_scalars(tag,
                                      measure_gradient_norms(model, norm='l1'),
                                      global_step=global_batch_idx_train)

        optimizer.zero_grad()
        (-tot_score).backward()
        maybe_log_gradients('train/grad_norms')
        clip_grad_value_(model.parameters(), 5.0)
        maybe_log_gradients('train/clipped_grad_norms')
        if (tb_writer is not None and global_batch_idx_train is not None
                and global_batch_idx_train % 200 == 0):
            # Once in a time we will perform a more costly diagnostic
            # to check the relative parameter change per minibatch.
            deltas = optim_step_and_measure_param_change(model, optimizer)
            tb_writer.add_scalars('train/relative_param_change_per_minibatch',
                                  deltas,
                                  global_step=global_batch_idx_train)
        else:
            optimizer.step()

    ans = (
        -tot_score.detach().cpu().item(),
        tot_frames.cpu().item(),
        all_frames.cpu().item(),
    )
    return ans