def test_rnnt_loss_gradient(self):
    """Compare k2.rnnt_loss against torchaudio's rnnt_loss, value and gradient.

    Does nothing when ``self.has_torch_rnnt_loss`` is false.  Otherwise,
    random ``am``/``lm`` scores are summed into full logits.  On every device
    in ``self.devices``, the two losses and their gradients with respect to
    the logits must agree to within ``atol=rtol=1e-2``.
    """
    if not self.has_torch_rnnt_loss:
        return

    import torchaudio.functional

    B, S, T, C = 5, 20, 300, 100
    frames = torch.randint(S, T, (B,))
    seq_length = torch.randint(3, S - 1, (B,))
    # Shrink S/T to the longest sampled symbol sequence / frame count.
    T = torch.max(frames)
    S = torch.max(seq_length)
    am_ = torch.randn((B, T, C), dtype=torch.float32)
    lm_ = torch.randn((B, S + 1, C), dtype=torch.float32)
    symbols_ = torch.randint(0, C - 1, (B, S))
    blank = C - 1

    # Column 2 holds each utterance's symbol count, column 3 its frame count
    # (they are passed to torchaudio as target/logit lengths below).
    boundary_ = torch.zeros((B, 4), dtype=torch.int64)
    boundary_[:, 2] = seq_length
    boundary_[:, 3] = frames

    for device in self.devices:
        # am: [B][T][C], lm: [B][S+1][C]
        am, lm = am_.to(device), lm_.to(device)
        symbols = symbols_.to(device)
        boundary = boundary_.to(device)

        # Broadcast to the full [B][T][S+1][C] logits lattice.
        logits = am.unsqueeze(2) + lm.unsqueeze(1)
        logits.requires_grad_()
        k2_loss = k2.rnnt_loss(
            logits=logits,
            symbols=symbols,
            termination_symbol=blank,
            boundary=boundary,
        )
        (k2_grad,) = torch.autograd.grad(k2_loss, logits)

        # Independent leaf copy so torchaudio's gradient is computed separately.
        ta_logits = logits.detach().clone().float()
        ta_logits.requires_grad_()
        ta_loss = torchaudio.functional.rnnt_loss(
            logits=ta_logits,
            targets=symbols.int(),
            logit_lengths=boundary[:, 3].int(),
            target_lengths=boundary[:, 2].int(),
            blank=blank,
        )
        (ta_grad,) = torch.autograd.grad(ta_loss, ta_logits)

        assert torch.allclose(k2_loss, ta_loss, atol=1e-2, rtol=1e-2)
        assert torch.allclose(k2_grad, ta_grad, atol=1e-2, rtol=1e-2)
def test_rnnt_loss_basic(self):
    """Check the k2 RNN-T loss variants on a tiny hand-written example.

    The reference value ``expected`` is ``-k2.mutual_information_recursion``
    on the first device in ``self.devices``.  On every device, the following
    must match it:

    - ``k2.rnnt_loss_simple``;
    - ``k2.rnnt_loss_smoothed`` with both only-scales at 0;
    - ``k2.rnnt_loss`` on the full logits;
    - torchaudio's ``rnnt_loss``, when available.

    The k2 variants are then re-checked after adding a random constant to
    every ``lm``/``am`` frame, which must leave the loss unchanged.
    """
    B = 1
    S = 3
    T = 4
    # C = 3
    # Reference loss, taken from whichever device runs first.  Previously
    # this was only bound on CPU, so a device list not starting with CPU
    # raised NameError at the first comparison.
    expected = None
    for device in self.devices:
        # lm: [B][S+1][C]
        lm = torch.tensor(
            [[[0, 0, 1], [0, 1, 1], [1, 0, 1], [2, 2, 0]]],
            dtype=torch.float,
            device=device,
        )
        # am: [B][T][C]
        am = torch.tensor(
            [[[0, 1, 2], [0, 0, 0], [0, 2, 4], [0, 3, 3]]],
            dtype=torch.float,
            device=device,
        )
        termination_symbol = 2
        symbols = torch.tensor([[0, 1, 0]], dtype=torch.long, device=device)

        px, py = k2.get_rnnt_logprobs(
            lm=lm,
            am=am,
            symbols=symbols,
            termination_symbol=termination_symbol,
        )
        assert px.shape == (B, S, T + 1)
        assert py.shape == (B, S + 1, T)
        assert symbols.shape == (B, S)
        m = k2.mutual_information_recursion(px=px, py=py, boundary=None)

        if expected is None:
            expected = -m
        assert torch.allclose(-m, expected.to(device))

        # test rnnt_loss_simple
        m = k2.rnnt_loss_simple(
            lm=lm,
            am=am,
            symbols=symbols,
            termination_symbol=termination_symbol,
            boundary=None,
            reduction="none",
        )
        assert torch.allclose(m, expected.to(device))

        # test rnnt_loss_smoothed
        m = k2.rnnt_loss_smoothed(
            lm=lm,
            am=am,
            symbols=symbols,
            termination_symbol=termination_symbol,
            lm_only_scale=0.0,
            am_only_scale=0.0,
            boundary=None,
            reduction="none",
        )
        assert torch.allclose(m, expected.to(device))

        probs = am.unsqueeze(2) + lm.unsqueeze(1)

        # test rnnt_loss
        m = k2.rnnt_loss(
            logits=probs,
            symbols=symbols,
            termination_symbol=termination_symbol,
            boundary=None,
            reduction="none",
        )
        assert torch.allclose(m, expected.to(device))

        # compare with torchaudio rnnt_loss
        if self.has_torch_rnnt_loss:
            import torchaudio.functional

            m = torchaudio.functional.rnnt_loss(
                logits=probs,
                targets=symbols.int(),
                logit_lengths=torch.tensor(
                    [T] * B, dtype=torch.int32, device=device
                ),
                target_lengths=torch.tensor(
                    [S] * B, dtype=torch.int32, device=device
                ),
                blank=termination_symbol,
                reduction="none",
            )
            assert torch.allclose(m, expected.to(device))

        # should be invariant to adding a constant for any frame.
        # Out-of-place adds, so no tensor is mutated behind our back.
        lm = lm + torch.randn(B, S + 1, 1, device=device)
        am = am + torch.randn(B, T, 1, device=device)

        m = k2.rnnt_loss_simple(
            lm=lm,
            am=am,
            symbols=symbols,
            termination_symbol=termination_symbol,
            boundary=None,
            reduction="none",
        )
        assert torch.allclose(m, expected.to(device))

        m = k2.rnnt_loss_smoothed(
            lm=lm,
            am=am,
            symbols=symbols,
            termination_symbol=termination_symbol,
            lm_only_scale=0.0,
            am_only_scale=0.0,
            boundary=None,
            reduction="none",
        )
        assert torch.allclose(m, expected.to(device))

        probs = am.unsqueeze(2) + lm.unsqueeze(1)
        m = k2.rnnt_loss(
            logits=probs,
            symbols=symbols,
            termination_symbol=termination_symbol,
            boundary=None,
            reduction="none",
        )
        assert torch.allclose(m, expected.to(device))
def test_rnnt_loss_pruned(self):
    """Smoke-test the pruned RNN-T loss pipeline for several prune widths.

    For each ``modified`` setting and each device in ``self.devices``:

    1. Print the unpruned ``k2.rnnt_loss`` on sigmoid-transformed float32
       logits.
    2. Get the px/py gradients returned by ``k2.rnnt_loss_simple``.
    3. For ``s_range`` in 2, 7, ..., 47, derive pruning ranges, prune
       ``am``/``lm`` and print ``k2.rnnt_loss_pruned``.

    Nothing is asserted; the losses are only printed.
    """
    B, T, S, C = 4, 300, 50, 10
    frames = torch.randint(S, T, (B,))
    seq_length = torch.randint(3, S - 1, (B,))
    # Shrink S/T to the longest sampled symbol sequence / frame count.
    T = torch.max(frames)
    S = torch.max(seq_length)
    am_ = torch.randn((B, T, C), dtype=torch.float64)
    lm_ = torch.randn((B, S + 1, C), dtype=torch.float64)
    symbols_ = torch.randint(0, C - 1, (B, S))
    blank = C - 1

    # Column 2 holds each utterance's symbol count, column 3 its frame count.
    boundary_ = torch.zeros((B, 4), dtype=torch.int64)
    boundary_[:, 2] = seq_length
    boundary_[:, 3] = frames

    for modified in (True, False):
        for device in self.devices:
            am = am_.to(device)
            lm = lm_.to(device)
            symbols = symbols_.to(device)
            boundary = boundary_.to(device)

            # normal rnnt: full logits, cast to float32, with a sigmoid
            # nonlinearity on top.
            full_logits = torch.sigmoid(
                am.unsqueeze(2).float() + lm.unsqueeze(1).float()
            )
            full_loss = k2.rnnt_loss(
                logits=full_logits,
                symbols=symbols,
                termination_symbol=blank,
                boundary=boundary,
                modified=modified,
            )
            print(f"unpruned rnnt loss with modified {modified} : {full_loss}")

            # Only the px/py gradients of the simple loss are needed; they
            # drive the choice of pruning ranges.
            _, (px_grad, py_grad) = k2.rnnt_loss_simple(
                lm=lm,
                am=am,
                symbols=symbols,
                termination_symbol=blank,
                boundary=boundary,
                modified=modified,
                return_grad=True,
                reduction="none",
            )

            for s_range in range(2, 50, 5):
                ranges = k2.get_rnnt_prune_ranges(
                    px_grad=px_grad,
                    py_grad=py_grad,
                    boundary=boundary,
                    s_range=s_range,
                )
                # Pruned logits: (B, T, s_range, C).
                am_p, lm_p = k2.do_rnnt_pruning(am=am, lm=lm, ranges=ranges)
                pruned_logits = torch.sigmoid(am_p + lm_p)
                pruned_loss = k2.rnnt_loss_pruned(
                    logits=pruned_logits,
                    symbols=symbols,
                    ranges=ranges,
                    termination_symbol=blank,
                    boundary=boundary,
                    modified=modified,
                    reduction="none",
                )
                print(f"pruning loss with range {s_range} : {pruned_loss}")
def test_rnnt_loss_random(self):
    """Check that the k2 RNN-T loss variants agree on random, ragged batches.

    For each ``modified`` in (True, False), the reference ``expected`` is
    ``-torch.mean`` of ``k2.mutual_information_recursion`` on the first
    device in ``self.devices``.  On every device, the following must match
    it:

    - ``k2.rnnt_loss_simple``;
    - ``k2.rnnt_loss_smoothed`` with both only-scales at 0;
    - ``k2.rnnt_loss``;
    - torchaudio's ``rnnt_loss``, for the unmodified case when available.

    The k2 variants must still match after adding a random constant to
    every ``lm``/``am`` frame.
    """
    B = 5
    S = 20
    T = 300
    C = 100
    frames = torch.randint(S, T, (B,))
    seq_length = torch.randint(3, S - 1, (B,))
    # Shrink S/T to the longest sampled symbol sequence / frame count.
    T = torch.max(frames)
    S = torch.max(seq_length)

    am_ = torch.randn((B, T, C), dtype=torch.float32)
    lm_ = torch.randn((B, S + 1, C), dtype=torch.float32)
    symbols_ = torch.randint(0, C - 1, (B, S))
    termination_symbol = C - 1

    # Column 2 holds each utterance's symbol count, column 3 its frame count.
    boundary_ = torch.zeros((B, 4), dtype=torch.int64)
    boundary_[:, 2] = seq_length
    boundary_[:, 3] = frames

    for modified in [True, False]:
        # Reference loss for this `modified` setting, taken from whichever
        # device runs first (not only when CPU happens to be first).
        expected = None
        for device in self.devices:
            # lm: [B][S+1][C]
            lm = lm_.to(device)
            # am: [B][T][C]
            am = am_.to(device)
            symbols = symbols_.to(device)
            boundary = boundary_.to(device)

            px, py = k2.get_rnnt_logprobs(
                lm=lm,
                am=am,
                symbols=symbols,
                termination_symbol=termination_symbol,
                boundary=boundary,
                modified=modified,
            )
            # Parenthesized: `a == b if c else d` parses as
            # `(a == b) if c else d`, which made the unmodified branch
            # assert a non-empty (always truthy) tuple.
            assert px.shape == ((B, S, T) if modified else (B, S, T + 1))
            assert py.shape == (B, S + 1, T)
            assert symbols.shape == (B, S)
            m = k2.mutual_information_recursion(
                px=px, py=py, boundary=boundary
            )

            if expected is None:
                expected = -torch.mean(m)
            assert torch.allclose(-torch.mean(m), expected.to(device))

            m = k2.rnnt_loss_simple(
                lm=lm,
                am=am,
                symbols=symbols,
                termination_symbol=termination_symbol,
                boundary=boundary,
                modified=modified,
            )
            assert torch.allclose(m, expected.to(device))

            m = k2.rnnt_loss_smoothed(
                lm=lm,
                am=am,
                symbols=symbols,
                termination_symbol=termination_symbol,
                lm_only_scale=0.0,
                am_only_scale=0.0,
                boundary=boundary,
                modified=modified,
            )
            assert torch.allclose(m, expected.to(device))

            probs = am.unsqueeze(2) + lm.unsqueeze(1)
            m = k2.rnnt_loss(
                logits=probs,
                symbols=symbols,
                termination_symbol=termination_symbol,
                boundary=boundary,
                modified=modified,
            )
            assert torch.allclose(m, expected.to(device))

            # compare with torchaudio rnnt_loss
            if self.has_torch_rnnt_loss and not modified:
                import torchaudio.functional

                m = torchaudio.functional.rnnt_loss(
                    logits=probs,
                    targets=symbols.int(),
                    logit_lengths=boundary[:, 3].int(),
                    target_lengths=boundary[:, 2].int(),
                    blank=termination_symbol,
                )
                assert torch.allclose(m, expected.to(device))

            # should be invariant to adding a constant for any frame.
            # Out-of-place adds: `lm_.to(device)` returns `lm_` itself when
            # it is already on `device`, so an in-place `+=` here would
            # silently perturb the shared inputs of later iterations.
            lm = lm + torch.randn(B, S + 1, 1, device=device)
            am = am + torch.randn(B, T, 1, device=device)

            m = k2.rnnt_loss_simple(
                lm=lm,
                am=am,
                symbols=symbols,
                termination_symbol=termination_symbol,
                boundary=boundary,
                modified=modified,
            )
            assert torch.allclose(m, expected.to(device))

            probs = am.unsqueeze(2) + lm.unsqueeze(1)
            m = k2.rnnt_loss(
                logits=probs,
                symbols=symbols,
                termination_symbol=termination_symbol,
                boundary=boundary,
                modified=modified,
            )
            assert torch.allclose(m, expected.to(device))

            m = k2.rnnt_loss_smoothed(
                lm=lm,
                am=am,
                symbols=symbols,
                termination_symbol=termination_symbol,
                lm_only_scale=0.0,
                am_only_scale=0.0,
                boundary=boundary,
                modified=modified,
            )
            assert torch.allclose(m, expected.to(device))