예제 #1
0
    def test_rnnt_loss_gradient(self):
        """Compare k2.rnnt_loss with torchaudio's rnnt_loss on random input.

        Both the loss value and the gradient with respect to the joint
        logits are compared, on every device in ``self.devices``. The test
        does nothing when ``self.has_torch_rnnt_loss`` is false.
        """
        if not self.has_torch_rnnt_loss:
            return

        import torchaudio.functional

        batch_size, vocab_size = 5, 100
        symbol_cap, frame_cap = 20, 300

        # Random per-sequence lengths; the padded sizes are the batch maxima.
        num_frames = torch.randint(symbol_cap, frame_cap, (batch_size,))
        num_symbols = torch.randint(3, symbol_cap - 1, (batch_size,))
        max_frames = torch.max(num_frames)
        max_symbols = torch.max(num_symbols)

        am_cpu = torch.randn(
            (batch_size, max_frames, vocab_size), dtype=torch.float32
        )
        lm_cpu = torch.randn(
            (batch_size, max_symbols + 1, vocab_size), dtype=torch.float32
        )
        symbols_cpu = torch.randint(
            0, vocab_size - 1, (batch_size, max_symbols)
        )
        # The last class index is used as the blank / termination symbol,
        # so the random symbols above never include it.
        blank = vocab_size - 1

        # Columns 0 and 1 stay zero; column 2 holds the symbol counts and
        # column 3 the frame counts.
        boundary_cpu = torch.zeros((batch_size, 4), dtype=torch.int64)
        boundary_cpu[:, 2] = num_symbols
        boundary_cpu[:, 3] = num_frames

        for device in self.devices:
            am = am_cpu.to(device)  # [B][T][C]
            lm = lm_cpu.to(device)  # [B][S+1][C]
            symbols = symbols_cpu.to(device)
            boundary = boundary_cpu.to(device)

            # Broadcast-sum into joint logits of shape [B][T][S+1][C].
            joint = (am[:, :, None, :] + lm[:, None, :, :]).requires_grad_()
            k2_loss = k2.rnnt_loss(
                logits=joint,
                symbols=symbols,
                termination_symbol=blank,
                boundary=boundary,
            )
            (k2_grad,) = torch.autograd.grad(k2_loss, joint)

            # Independent leaf copy so the two backward passes do not
            # share an autograd graph.
            joint_copy = joint.detach().clone().float().requires_grad_()
            torch_loss = torchaudio.functional.rnnt_loss(
                logits=joint_copy,
                targets=symbols.int(),
                logit_lengths=boundary[:, 3].int(),
                target_lengths=boundary[:, 2].int(),
                blank=blank,
            )
            (torch_grad,) = torch.autograd.grad(torch_loss, joint_copy)

            assert torch.allclose(k2_loss, torch_loss, atol=1e-2, rtol=1e-2)
            assert torch.allclose(k2_grad, torch_grad, atol=1e-2, rtol=1e-2)
예제 #2
0
    def test_rnnt_loss_basic(self):
        """Check the k2 RNN-T loss variants on one tiny hand-written example.

        The reference value is the negated output of
        ``k2.mutual_information_recursion`` computed on the CPU device; the
        other loss entry points (and torchaudio's, when available) must
        reproduce it, both before and after a random per-frame constant is
        added to ``lm`` and ``am``.

        NOTE(review): ``expected`` is only assigned on the CPU iteration, so
        this relies on CPU being visited first in ``self.devices``.
        """
        B, S, T = 1, 3, 4  # C = 3
        termination_symbol = 2

        lm_values = [[[0, 0, 1], [0, 1, 1], [1, 0, 1], [2, 2, 0]]]
        am_values = [[[0, 1, 2], [0, 0, 0], [0, 2, 4], [0, 3, 3]]]

        for device in self.devices:
            # lm: [B][S+1][C]
            lm = torch.tensor(lm_values, dtype=torch.float, device=device)
            # am: [B][T][C]
            am = torch.tensor(am_values, dtype=torch.float, device=device)
            symbols = torch.tensor(
                [[0, 1, 0]], dtype=torch.long, device=device
            )

            px, py = k2.get_rnnt_logprobs(
                lm=lm,
                am=am,
                symbols=symbols,
                termination_symbol=termination_symbol,
            )
            assert px.shape == (B, S, T + 1)
            assert py.shape == (B, S + 1, T)
            assert symbols.shape == (B, S)
            loss = -k2.mutual_information_recursion(
                px=px, py=py, boundary=None
            )

            if device == torch.device("cpu"):
                expected = loss
            assert torch.allclose(loss, expected.to(device))

            # Arguments shared by every loss call below.
            shared = dict(
                symbols=symbols,
                termination_symbol=termination_symbol,
                boundary=None,
                reduction="none",
            )
            # With both scales at zero the smoothed loss is expected to
            # equal the simple one.
            no_smoothing = dict(lm_only_scale=0.0, am_only_scale=0.0)

            # test rnnt_loss_simple
            loss = k2.rnnt_loss_simple(lm=lm, am=am, **shared)
            assert torch.allclose(loss, expected.to(device))

            # test rnnt_loss_smoothed
            loss = k2.rnnt_loss_smoothed(
                lm=lm, am=am, **no_smoothing, **shared
            )
            assert torch.allclose(loss, expected.to(device))

            probs = am.unsqueeze(2) + lm.unsqueeze(1)

            # test rnnt_loss
            loss = k2.rnnt_loss(logits=probs, **shared)
            assert torch.allclose(loss, expected.to(device))

            # compare with torchaudio rnnt_loss
            if self.has_torch_rnnt_loss:
                import torchaudio.functional

                loss = torchaudio.functional.rnnt_loss(
                    logits=probs,
                    targets=symbols.int(),
                    logit_lengths=torch.full(
                        (B,), T, dtype=torch.int32, device=device
                    ),
                    target_lengths=torch.full(
                        (B,), S, dtype=torch.int32, device=device
                    ),
                    blank=termination_symbol,
                    reduction="none",
                )
                assert torch.allclose(loss, expected.to(device))

            # should be invariant to adding a constant for any frame.
            lm += torch.randn(B, S + 1, 1, device=device)
            am += torch.randn(B, T, 1, device=device)

            loss = k2.rnnt_loss_simple(lm=lm, am=am, **shared)
            assert torch.allclose(loss, expected.to(device))

            loss = k2.rnnt_loss_smoothed(
                lm=lm, am=am, **no_smoothing, **shared
            )
            assert torch.allclose(loss, expected.to(device))

            probs = am.unsqueeze(2) + lm.unsqueeze(1)
            loss = k2.rnnt_loss(logits=probs, **shared)
            assert torch.allclose(loss, expected.to(device))
예제 #3
0
    def test_rnnt_loss_pruned(self):
        """Smoke-test the pruned RNN-T loss over a sweep of pruning ranges.

        For each ``modified`` setting and each device in ``self.devices``
        this prints the unpruned loss, then the pruned loss for ``s_range``
        values 2, 7, ..., 47. Nothing is asserted; the test only checks that
        the calls run.
        """
        batch_size, vocab_size = 4, 10
        frame_cap, symbol_cap = 300, 50

        # Random per-sequence lengths; the padded sizes are the batch maxima.
        frames = torch.randint(symbol_cap, frame_cap, (batch_size,))
        seq_length = torch.randint(3, symbol_cap - 1, (batch_size,))
        max_frames = torch.max(frames)
        max_symbols = torch.max(seq_length)

        am_cpu = torch.randn(
            (batch_size, max_frames, vocab_size), dtype=torch.float64
        )
        lm_cpu = torch.randn(
            (batch_size, max_symbols + 1, vocab_size), dtype=torch.float64
        )
        symbols_cpu = torch.randint(
            0, vocab_size - 1, (batch_size, max_symbols)
        )
        terminal_symbol = vocab_size - 1

        # Columns 0 and 1 stay zero; column 2 holds the symbol counts and
        # column 3 the frame counts.
        boundary_cpu = torch.zeros((batch_size, 4), dtype=torch.int64)
        boundary_cpu[:, 2] = seq_length
        boundary_cpu[:, 3] = frames

        for modified in (True, False):
            for device in self.devices:
                am = am_cpu.to(device)
                lm = lm_cpu.to(device)
                symbols = symbols_cpu.to(device)
                boundary = boundary_cpu.to(device)

                # normal rnnt: the full joint is built in float32 and
                # passed through a nonlinear transform (sigmoid).
                full_joint = am.unsqueeze(2).float() + lm.unsqueeze(1).float()
                k2_loss = k2.rnnt_loss(
                    logits=torch.sigmoid(full_joint),
                    symbols=symbols,
                    termination_symbol=terminal_symbol,
                    boundary=boundary,
                    modified=modified,
                )

                print(
                    f"unpruned rnnt loss with modified {modified} : {k2_loss}"
                )

                # pruning: the simple loss supplies the gradients from
                # which the pruning ranges are derived.
                simple_loss, (px_grad, py_grad) = k2.rnnt_loss_simple(
                    lm=lm,
                    am=am,
                    symbols=symbols,
                    termination_symbol=terminal_symbol,
                    boundary=boundary,
                    modified=modified,
                    return_grad=True,
                    reduction="none",
                )

                for r in range(2, 50, 5):
                    ranges = k2.get_rnnt_prune_ranges(
                        px_grad=px_grad,
                        py_grad=py_grad,
                        boundary=boundary,
                        s_range=r,
                    )
                    # (B, T, r, C)
                    am_p, lm_p = k2.do_rnnt_pruning(
                        am=am, lm=lm, ranges=ranges
                    )

                    # same nonlinear transform as the unpruned path
                    pruned_logits = torch.sigmoid(am_p + lm_p)

                    pruned_loss = k2.rnnt_loss_pruned(
                        logits=pruned_logits,
                        symbols=symbols,
                        ranges=ranges,
                        termination_symbol=terminal_symbol,
                        boundary=boundary,
                        modified=modified,
                        reduction="none",
                    )
                    print(f"pruning loss with range {r} : {pruned_loss}")
예제 #4
0
    def test_rnnt_loss_random(self):
        """Cross-check the k2 RNN-T loss entry points on random input.

        For each ``modified`` setting and each device in ``self.devices``:

        * the reference value is the negated mean of
          ``k2.mutual_information_recursion`` computed on the CPU device
          (so CPU must be visited before any other device);
        * ``rnnt_loss_simple``, ``rnnt_loss_smoothed`` (both scales zero)
          and ``rnnt_loss`` must reproduce it;
        * torchaudio's ``rnnt_loss`` must match as well when it is
          available and ``modified`` is false;
        * all k2 losses must be unchanged after a random per-frame
          constant is added to ``lm`` and ``am``.
        """
        B = 5
        C = 100
        frames = torch.randint(20, 300, (B,))
        seq_length = torch.randint(3, 20 - 1, (B,))
        # Plain ints (not 0-dim tensors) so the shape tuples below compare
        # cleanly against torch.Size.
        T = int(torch.max(frames))
        S = int(torch.max(seq_length))

        am_ = torch.randn((B, T, C), dtype=torch.float32)
        lm_ = torch.randn((B, S + 1, C), dtype=torch.float32)
        symbols_ = torch.randint(0, C - 1, (B, S))
        termination_symbol = C - 1

        # Columns 0 and 1 stay zero; column 2 holds the symbol counts and
        # column 3 the frame counts.
        boundary_ = torch.zeros((B, 4), dtype=torch.int64)
        boundary_[:, 2] = seq_length
        boundary_[:, 3] = frames

        for modified in [True, False]:
            for device in self.devices:
                # lm: [B][S+1][C]
                lm = lm_.to(device)
                # am: [B][T][C]
                am = am_.to(device)
                symbols = symbols_.to(device)
                boundary = boundary_.to(device)

                px, py = k2.get_rnnt_logprobs(
                    lm=lm,
                    am=am,
                    symbols=symbols,
                    termination_symbol=termination_symbol,
                    boundary=boundary,
                    modified=modified,
                )
                # The conditional must be parenthesised: without the
                # parentheses the statement parses as
                # `(px.shape == (B, S, T)) if modified else (B, S, T + 1)`,
                # which asserts an always-truthy tuple when `modified` is
                # false and so never checks the shape.
                expected_px_shape = (B, S, T) if modified else (B, S, T + 1)
                assert px.shape == expected_px_shape
                assert py.shape == (B, S + 1, T)
                assert symbols.shape == (B, S)
                m = k2.mutual_information_recursion(
                    px=px, py=py, boundary=boundary
                )

                if device == torch.device("cpu"):
                    expected = -torch.mean(m)
                assert torch.allclose(-torch.mean(m), expected.to(device))

                # The losses below use their default reduction and are
                # compared with the negated mean computed above.
                m = k2.rnnt_loss_simple(
                    lm=lm,
                    am=am,
                    symbols=symbols,
                    termination_symbol=termination_symbol,
                    boundary=boundary,
                    modified=modified,
                )
                assert torch.allclose(m, expected.to(device))

                m = k2.rnnt_loss_smoothed(
                    lm=lm,
                    am=am,
                    symbols=symbols,
                    termination_symbol=termination_symbol,
                    lm_only_scale=0.0,
                    am_only_scale=0.0,
                    boundary=boundary,
                    modified=modified,
                )
                assert torch.allclose(m, expected.to(device))

                probs = am.unsqueeze(2) + lm.unsqueeze(1)
                m = k2.rnnt_loss(
                    logits=probs,
                    symbols=symbols,
                    termination_symbol=termination_symbol,
                    boundary=boundary,
                    modified=modified,
                )
                assert torch.allclose(m, expected.to(device))

                # compare with torchaudio rnnt_loss (skipped for the
                # modified topology)
                if self.has_torch_rnnt_loss and not modified:
                    import torchaudio.functional

                    m = torchaudio.functional.rnnt_loss(
                        logits=probs,
                        targets=symbols.int(),
                        logit_lengths=boundary[:, 3].int(),
                        target_lengths=boundary[:, 2].int(),
                        blank=termination_symbol,
                    )
                    assert torch.allclose(m, expected.to(device))

                # should be invariant to adding a constant for any frame.
                # NOTE(review): these are in-place updates; if `.to(device)`
                # returned the source tensor itself (presumably the case on
                # CPU), `lm_`/`am_` are shifted too. The checks still hold
                # because of the invariance being tested -- confirm.
                lm += torch.randn(B, S + 1, 1, device=device)
                am += torch.randn(B, T, 1, device=device)

                m = k2.rnnt_loss_simple(
                    lm=lm,
                    am=am,
                    symbols=symbols,
                    termination_symbol=termination_symbol,
                    boundary=boundary,
                    modified=modified,
                )
                assert torch.allclose(m, expected.to(device))

                probs = am.unsqueeze(2) + lm.unsqueeze(1)
                m = k2.rnnt_loss(
                    logits=probs,
                    symbols=symbols,
                    termination_symbol=termination_symbol,
                    boundary=boundary,
                    modified=modified,
                )
                assert torch.allclose(m, expected.to(device))

                m = k2.rnnt_loss_smoothed(
                    lm=lm,
                    am=am,
                    symbols=symbols,
                    termination_symbol=termination_symbol,
                    lm_only_scale=0.0,
                    am_only_scale=0.0,
                    boundary=boundary,
                    modified=modified,
                )
                assert torch.allclose(m, expected.to(device))