def _compute_mmi_loss_exact_non_optimized(
        nnet_output: torch.Tensor,
        texts: List[str],
        supervision_segments: torch.Tensor,
        graph_compiler: MmiTrainingGraphCompiler,
        den_scale: float = 1.0
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    '''
    Compute the MMI objective with two separate, unpruned intersections.

    See :func:`_compute_mmi_loss_exact_optimized` for the meaning of the
    arguments.

    This variant is the more readable one, but it calls
    k2.intersect_dense twice: once for the numerator graphs and once
    for the denominator graphs.

    Note:
      It uses less memory at the cost of speed, i.e. it is slower.

    Returns:
      The three values produced by get_tot_objf_and_num_frames for the
      per-sequence scores `num - den_scale * den`.
    '''
    numerator_graphs, denominator_graphs = graph_compiler.compile(
        texts, replicate_den=True)
    acoustics = k2.DenseFsaVec(nnet_output, supervision_segments)

    # Same beam for both intersections.
    numerator_lats, denominator_lats = [
        k2.intersect_dense(graphs, acoustics, output_beam=10.0)
        for graphs in (numerator_graphs, denominator_graphs)
    ]

    # Total scores in the log semiring, accumulated in double precision.
    numerator_scores, denominator_scores = [
        lats.get_tot_scores(log_semiring=True, use_double_scores=True)
        for lats in (numerator_lats, denominator_lats)
    ]

    mmi_scores = numerator_scores - den_scale * denominator_scores

    # Column 2 of the supervision segments is passed as the frame counts.
    objf, kept_frames, total_frames = get_tot_objf_and_num_frames(
        mmi_scores, supervision_segments[:, 2])
    return objf, kept_frames, total_frames
def forward(self, log_probs: torch.Tensor, targets: torch.Tensor,
            input_lengths: torch.Tensor,
            target_lengths: torch.Tensor) -> torch.Tensor:
    """
    Compute a CTC-style loss by intersecting target graphs with the
    network output on the CPU.

    Args:
      log_probs:
        Network output. It is permuted with ``(1, 0, 2)`` and moved to
        the CPU below; after that it is [N, T, C], i.e.
        batchSize x seqLength x alphabet_size.
      targets:
        Label sequences; moved to the CPU and handed to
        ``self.graph_compiler.compile``.
      input_lengths:
        1-D tensor with one entry per sequence; it becomes the third
        column of the supervision segments. It may live on any device.
      target_lengths:
        Handed unchanged to ``self.graph_compiler.compile``.

    Returns:
      The negated total score returned by get_tot_objf_and_num_frames.
    """
    # now log_probs is [N, T, C]  batchSize x seqLength x alphabet_size
    log_probs = log_probs.permute(1, 0, 2).cpu()

    num_seqs = input_lengths.shape[0]
    # Every column is built on the CPU as int32. torch.stack needs all of
    # its inputs on one device, and the first two columns are created on
    # the CPU, so input_lengths has to be brought there as well (it may
    # arrive on the GPU together with the network output).
    supervision_segments = torch.stack(
        (torch.arange(num_seqs, dtype=torch.int32),
         torch.zeros(num_seqs, dtype=torch.int32),
         input_lengths.cpu().to(torch.int32)), 1)

    # Order the segments by decreasing number of frames and remember the
    # permutation so that the graphs can be reordered the same way.
    indices = torch.argsort(supervision_segments[:, 2], descending=True)
    supervision_segments = supervision_segments[indices]

    dense_fsa_vec = k2.DenseFsaVec(log_probs, supervision_segments)

    decoding_graph = self.graph_compiler.compile(targets.cpu(),
                                                 target_lengths)
    decoding_graph = k2.index(decoding_graph,
                              indices.to(torch.int32)).to(log_probs.device)

    target_graph = k2.intersect_dense(decoding_graph, dense_fsa_vec, 10.0)
    tot_scores = k2.get_tot_scores(target_graph,
                                   log_semiring=True,
                                   use_double_scores=True)

    (tot_score, tot_frames,
     all_frames) = get_tot_objf_and_num_frames(tot_scores,
                                               supervision_segments[:, 2])
    return -tot_score
def test_two_fsas(self):
    """Intersect two different FSAs with a two-sequence DenseFsaVec and
    check the gradients that reach the arc scores of each FSA."""
    s1 = '''
        0 1 1 1.0
        1 1 1 50.0
        1 2 2 2.0
        2 3 -1 3.0
        3
    '''

    s2 = '''
        0 1 1 1.0
        1 2 2 2.0
        2 3 -1 3.0
        3
    '''

    fsa1, fsa2 = [k2.Fsa.from_str(text) for text in (s1, s2)]
    for fsa in (fsa1, fsa2):
        fsa.requires_grad_(True)

    fsa_vec = k2.create_fsa_vec([fsa1, fsa2])

    first_seq = [[0.1, 0.2, 0.3], [0.04, 0.05, 0.06], [0.0, 0.0, 0.0]]
    second_seq = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.0, 0.0, 0.0]]
    log_prob = torch.tensor([first_seq, second_seq],
                            dtype=torch.float32,
                            requires_grad=True)

    supervision_segments = torch.tensor([[0, 0, 3], [1, 0, 2]],
                                        dtype=torch.int32)
    dense_fsa_vec = k2.DenseFsaVec(log_prob, supervision_segments)
    out_fsa = k2.intersect_dense(fsa_vec, dense_fsa_vec, output_beam=100000)
    assert out_fsa.shape == (2, None, None), 'There should be two FSAs!'

    scores = k2.get_tot_scores(out_fsa,
                               log_semiring=False,
                               use_float_scores=True)
    scores.sum().backward()

    # `expected` results are computed using gtn.
    # See https://bit.ly/3oYObeb
    #
    # TODO(dan): fix this. The checks on the output scores and on
    # log_prob.grad are disabled; the reference values were
    #   out_fsa.scores: [1.2, 2.06, 3.0, 1.2, 50.5, 2.0, 3.0]
    #   log_prob.grad:  [0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0, 0, 0,
    #                    0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0]
    #                   (reshaped like log_prob)
    print("fsa2 is ", str(fsa2))

    expected_grad_fsa1 = torch.tensor([1.0, 1.0, 1.0, 1.0])
    expected_grad_fsa2 = torch.tensor([1.0, 1.0, 1.0])

    assert torch.allclose(expected_grad_fsa1, fsa1.scores.grad)
    assert torch.allclose(expected_grad_fsa2, fsa2.scores.grad)
def _intersect_calc_scores_mmi_pruned(
    self,
    dense_fsa_vec: k2.DenseFsaVec,
    num_graphs: 'k2.Fsa',
    den_graph: 'k2.Fsa',
    return_lats: bool = True,
):
    """
    Compute numerator and denominator total scores for MMI, using an
    unpruned intersection for the numerator and a pruned one for the
    denominator.

    Args:
      dense_fsa_vec: The network output wrapped as a k2.DenseFsaVec.
      num_graphs: Numerator graphs; there must be as many as there are
        sequences in `dense_fsa_vec`.
      den_graph: Denominator graph; must be on the same device.
      return_lats: If True, also return both lattices and label their arcs
        with a "seqframe_idx" attribute.

    Returns:
      ``(num_tot_scores, den_tot_scores, num_lats, den_lats)``; the last
      two entries are None when `return_lats` is False.
    """
    device = dense_fsa_vec.device
    assert device == num_graphs.device and device == den_graph.device

    assert dense_fsa_vec.dim0() == num_graphs.shape[0]

    conf = self.intersect_conf
    # The extra arc attribute is only requested when lattices are returned.
    seqframe_name = "seqframe_idx" if return_lats else None

    num_lats = k2.intersect_dense(
        a_fsas=num_graphs,
        b_fsas=dense_fsa_vec,
        output_beam=conf.output_beam,
        seqframe_idx_name=seqframe_name,
    )
    den_lats = k2.intersect_dense_pruned(
        a_fsas=den_graph,
        b_fsas=dense_fsa_vec,
        search_beam=conf.search_beam,
        output_beam=conf.output_beam,
        min_active_states=conf.min_active_states,
        max_active_states=conf.max_active_states,
        seqframe_idx_name=seqframe_name,
    )

    # use_double_scores=True does matter
    # since otherwise it sometimes makes rounding errors
    score_opts = dict(log_semiring=True, use_double_scores=True)
    num_tot_scores = num_lats.get_tot_scores(**score_opts)
    den_tot_scores = den_lats.get_tot_scores(**score_opts)

    if not return_lats:
        return num_tot_scores, den_tot_scores, None, None
    return num_tot_scores, den_tot_scores, num_lats, den_lats
def test_case1(self):
    """One frame, one label: the k2 score must match torch's ctc_loss,
    both in value and in gradient."""
    devices = [torch.device('cpu')]
    if torch.cuda.is_available():
        devices += [torch.device('cuda')]

    log_softmax = torch.nn.functional.log_softmax

    for device in devices:
        # five classes: <blk> plus the labels a, b, c, d
        uniform = torch.tensor([0.2, 0.2, 0.2, 0.2, 0.2]).to(device)
        k2_activation = uniform.detach().clone()

        # (T, N, C)
        torch_activation = uniform.reshape(1, 1, -1).requires_grad_(True)

        # (N, T, C)
        k2_activation = k2_activation.reshape(1, 1,
                                              -1).requires_grad_(True)

        # (T, N, C)
        torch_log_probs = log_softmax(torch_activation, dim=-1)

        # we have only one sequence and its label is `a`
        targets = torch.tensor([1]).to(device)
        input_lengths = torch.tensor([1]).to(device)
        target_lengths = torch.tensor([1]).to(device)
        torch_loss = torch.nn.functional.ctc_loss(
            log_probs=torch_log_probs,
            targets=targets,
            input_lengths=input_lengths,
            target_lengths=target_lengths,
            reduction='none')

        reference = torch.tensor([1.6094379425049]).to(device)
        assert torch.allclose(torch_loss, reference)

        # (N, T, C)
        k2_log_probs = log_softmax(k2_activation, dim=-1)

        supervision_segments = torch.tensor([[0, 0, 1]], dtype=torch.int32)
        dense_fsa_vec = k2.DenseFsaVec(k2_log_probs,
                                       supervision_segments).to(device)

        # Compose the inverted CTC topology with the transcript, then
        # invert back to obtain the decoding graph.
        ctc_topo_inv = k2.arc_sort(
            build_ctc_topo([0, 1, 2, 3, 4]).invert_())
        linear_fsa = k2.linear_fsa([1])
        decoding_graph = k2.intersect(ctc_topo_inv, linear_fsa)
        decoding_graph = k2.connect(decoding_graph).invert_().to(device)

        target_graph = k2.intersect_dense(decoding_graph, dense_fsa_vec,
                                          100.0)
        k2_scores = target_graph.get_tot_scores(log_semiring=True,
                                                use_double_scores=False)
        assert torch.allclose(torch_loss, -1 * k2_scores)

        torch_loss.backward()
        (-k2_scores).backward()
        assert torch.allclose(torch_activation.grad, k2_activation.grad)
def test_simple(self):
    """Intersect one FSA with a single two-frame sequence and check the
    output arc scores and both gradients."""
    s = '''
        0 1 1 1.0
        1 1 1 50.0
        1 2 2 2.0
        2 3 -1 3.0
        3
    '''
    fsa = k2.Fsa.from_str(s)
    fsa.requires_grad_(True)
    fsa_vec = k2.create_fsa_vec([fsa])

    frames = [[0.1, 0.2, 0.3], [0.04, 0.05, 0.06]]
    log_prob = torch.tensor([frames],
                            dtype=torch.float32,
                            requires_grad=True)
    supervision_segments = torch.tensor([[0, 0, 2]], dtype=torch.int32)
    dense_fsa_vec = k2.DenseFsaVec(log_prob, supervision_segments)

    out_fsa = k2.intersect_dense(fsa_vec, dense_fsa_vec, output_beam=100000)
    scores = k2.get_tot_scores(out_fsa,
                               log_semiring=False,
                               use_float_scores=True)
    scores.sum().backward()

    # `expected` results are computed using gtn.
    # See https://bit.ly/3oYObeb
    expected_scores_out_fsa = torch.tensor([1.2, 2.06, 3.0])
    expected_grad_fsa = torch.tensor([1.0, 0.0, 1.0, 1.0])
    expected_grad_log_prob = torch.tensor(
        [0.0, 1.0, 0.0, 0.0, 0.0, 1.0]).reshape_as(log_prob)

    assert torch.allclose(out_fsa.scores, expected_scores_out_fsa)
    assert torch.allclose(expected_grad_fsa, fsa.scores.grad)
    assert torch.allclose(expected_grad_log_prob, log_prob.grad)
def test_case3(self):
    """Three frames, labels `b,c`: the k2 score must match torch's
    ctc_loss, both in value and in gradient."""
    devices = [torch.device('cpu')]
    if torch.cuda.is_available():
        devices += [torch.device('cuda')]

    log_softmax = torch.nn.functional.log_softmax

    for device in devices:
        frames = [
            [-5, -4, -3, -2, -1],
            [-10, -9, -8, -7, -6],
            [-15, -14, -13, -12, -11.],
        ]
        # (T, N, C)
        torch_activation = torch.tensor([frames]).permute(
            1, 0, 2).to(device).requires_grad_(True)
        torch_activation = torch_activation.to(torch.float32)
        torch_activation.requires_grad_(True)

        k2_activation = torch_activation.detach().clone().requires_grad_(
            True)

        # (T, N, C)
        torch_log_probs = log_softmax(torch_activation, dim=-1)

        # we have only one sequence and its labels are `b,c`
        targets = torch.tensor([2, 3]).to(device)
        input_lengths = torch.tensor([3]).to(device)
        target_lengths = torch.tensor([2]).to(device)

        torch_loss = torch.nn.functional.ctc_loss(
            log_probs=torch_log_probs,
            targets=targets,
            input_lengths=input_lengths,
            target_lengths=target_lengths,
            reduction='none')

        # (T, N, C) -> (N, T, C)
        act = k2_activation.permute(1, 0, 2)
        k2_log_probs = log_softmax(act, dim=-1)

        supervision_segments = torch.tensor([[0, 0, 3]], dtype=torch.int32)
        dense_fsa_vec = k2.DenseFsaVec(k2_log_probs,
                                       supervision_segments).to(device)

        ctc_topo_inv = k2.arc_sort(
            build_ctc_topo([0, 1, 2, 3, 4]).invert_())
        linear_fsa = k2.linear_fsa([2, 3])
        decoding_graph = k2.intersect(ctc_topo_inv, linear_fsa)
        decoding_graph = k2.connect(decoding_graph).invert_().to(device)

        target_graph = k2.intersect_dense(decoding_graph, dense_fsa_vec,
                                          100.0)
        k2_scores = target_graph.get_tot_scores(log_semiring=True,
                                                use_double_scores=False)
        assert torch.allclose(torch_loss, -1 * k2_scores)

        reference = torch.tensor([4.938850402832]).to(device)
        assert torch.allclose(torch_loss, reference)

        torch_loss.backward()
        (-k2_scores).backward()
        assert torch.allclose(torch_activation.grad, k2_activation.grad)
def _intersect_calc_scores_mmi_exact(
    self,
    dense_fsa_vec: k2.DenseFsaVec,
    num_graphs: 'k2.Fsa',
    den_graph: 'k2.Fsa',
    return_lats: bool = True,
):
    """
    Compute numerator and denominator total scores for MMI with a single
    unpruned call to k2.intersect_dense.

    The numerator graphs and the denominator graph are concatenated and
    interleaved as ``num_0, den, num_1, den, ...``; both members of each
    pair are mapped onto the same sequence of `dense_fsa_vec`.

    Args:
      dense_fsa_vec: The network output wrapped as a k2.DenseFsaVec.
      num_graphs: Numerator graphs; there must be as many as there are
        sequences in `dense_fsa_vec`.
      den_graph: Denominator graph; must be on the same device.
      return_lats: If True, also return the numerator and denominator
        lattices and label their arcs with a "seqframe_idx" attribute.

    Returns:
      ``(num_tot_scores, den_tot_scores, num_lats, den_lats)``; the last
      two entries are None when `return_lats` is False.
    """
    device = dense_fsa_vec.device
    assert device == num_graphs.device and device == den_graph.device

    num_fsas = num_graphs.shape[0]
    assert dense_fsa_vec.dim0() == num_fsas

    # Work on copies of the input graphs.
    num_graphs = num_graphs.clone()
    den_graph = den_graph.clone()

    combined = k2.cat([num_graphs, den_graph])

    # NOTE: The a_to_b_map in k2.intersect_dense must be sorted
    # so the following reorders the concatenated graphs.
    seq_ids = torch.arange(num_fsas, dtype=torch.int32)  # [0, 1, 2, ...]
    # The denominator graph sits at position num_fsas of `combined`:
    # [num_fsas, num_fsas, num_fsas, ...]
    den_ids = torch.full((num_fsas,), num_fsas, dtype=torch.int32)
    # [0, num_fsas, 1, num_fsas, 2, num_fsas, ...]
    interleaved_ids = torch.stack([seq_ids, den_ids],
                                  dim=1).reshape(-1).to(device)
    interleaved_graphs = k2.index_fsa(combined, interleaved_ids)

    # [0, 0, 1, 1, 2, 2, ...]
    a_to_b_map = torch.stack([seq_ids, seq_ids],
                             dim=1).reshape(-1).to(device)

    num_den_lats = k2.intersect_dense(
        a_fsas=interleaved_graphs,
        b_fsas=dense_fsa_vec,
        output_beam=self.intersect_conf.output_beam,
        a_to_b_map=a_to_b_map,
        seqframe_idx_name="seqframe_idx" if return_lats else None,
    )

    all_tot_scores = num_den_lats.get_tot_scores(log_semiring=True,
                                                 use_double_scores=True)
    # Even positions are numerators, odd positions are denominators.
    num_tot_scores = all_tot_scores[::2]
    den_tot_scores = all_tot_scores[1::2]

    if not return_lats:
        return num_tot_scores, den_tot_scores, None, None

    even_positions = torch.arange(0, 2 * num_fsas, 2,
                                  dtype=torch.int32).to(device)
    return (
        num_tot_scores,
        den_tot_scores,
        k2.index_fsa(num_den_lats, even_positions),
        k2.index_fsa(num_den_lats, even_positions + 1),
    )
def test_random_case1(self):
    """Random single-sequence comparison of the k2 score against
    torch's ctc_loss, in value and (scaled) gradient."""
    # 1 sequence
    devices = [torch.device('cpu')]
    if torch.cuda.is_available():
        devices += [torch.device('cuda', 0)]

    log_softmax = torch.nn.functional.log_softmax

    for device in devices:
        T = torch.randint(10, 100, (1,)).item()
        C = torch.randint(20, 30, (1,)).item()

        # The activation has 10 more frames than the T frames covered by
        # input_lengths and by the supervision segment.
        torch_activation = torch.rand((1, T + 10, C),
                                      dtype=torch.float32,
                                      device=device).requires_grad_(True)
        k2_activation = torch_activation.detach().clone().requires_grad_(
            True)

        # [N, T, C] -> [T, N, C]
        torch_log_probs = log_softmax(torch_activation.permute(1, 0, 2),
                                      dim=-1)

        input_lengths = torch.tensor([T]).to(device)
        target_lengths = torch.randint(1, T, (1,)).to(device)
        targets = torch.randint(1, C - 1,
                                (target_lengths.item(),)).to(device)

        torch_loss = torch.nn.functional.ctc_loss(
            log_probs=torch_log_probs,
            targets=targets,
            input_lengths=input_lengths,
            target_lengths=target_lengths,
            reduction='none')

        k2_log_probs = log_softmax(k2_activation, dim=-1)
        supervision_segments = torch.tensor([[0, 0, T]], dtype=torch.int32)
        dense_fsa_vec = k2.DenseFsaVec(k2_log_probs,
                                       supervision_segments).to(device)

        ctc_topo_inv = k2.arc_sort(
            build_ctc_topo(list(range(C))).invert_())
        linear_fsa = k2.linear_fsa([targets.tolist()])
        decoding_graph = k2.intersect(ctc_topo_inv, linear_fsa)
        decoding_graph = k2.connect(decoding_graph).invert_().to(device)

        target_graph = k2.intersect_dense(decoding_graph, dense_fsa_vec,
                                          100.0)
        k2_scores = target_graph.get_tot_scores(log_semiring=True,
                                                use_double_scores=False)
        assert torch.allclose(torch_loss, -1 * k2_scores)

        # Back-propagate a randomly scaled loss through both paths.
        scale = torch.rand_like(torch_loss) * 100
        (torch_loss * scale).sum().backward()
        (-k2_scores * scale).sum().backward()
        assert torch.allclose(torch_activation.grad,
                              k2_activation.grad,
                              atol=1e-2)
def test_two_fsas_long(self):
    """As test_two_fsas, but with a long DenseFsaVec for easier
    profiling; also checks the per-arc frame index attributes."""
    s1 = '''
        0 1 1 1.0
        1 1 1 50.0
        1 2 2 2.0
        2 3 -1 3.0
        3
    '''

    s2 = '''
        0 1 1 1.0
        1 2 2 2.0
        2 3 -1 3.0
        3
    '''

    devices = [torch.device('cpu')]
    if torch.cuda.is_available():
        devices += [torch.device('cuda', 0)]

    for device in devices:
        fsa1, fsa2 = [k2.Fsa.from_str(text) for text in (s1, s2)]
        for fsa in (fsa1, fsa2):
            fsa.requires_grad_(True)

        fsa_vec = k2.create_fsa_vec([fsa1, fsa2])
        log_prob = torch.rand((2, 100, 3),
                              dtype=torch.float32,
                              device=device,
                              requires_grad=True)

        supervision_segments = torch.tensor([[0, 1, 95], [1, 20, 50]],
                                            dtype=torch.int32)
        dense_fsa_vec = k2.DenseFsaVec(log_prob, supervision_segments)

        fsa_vec = fsa_vec.to(device)
        out_fsa = k2.intersect_dense(fsa_vec,
                                     dense_fsa_vec,
                                     output_beam=100000,
                                     seqframe_idx_name='seqframe',
                                     frame_idx_name='frame')

        expected_seqframe = torch.arange(96).to(torch.int32).to(device)
        assert torch.allclose(out_fsa.seqframe, expected_seqframe)

        # the second output FSA is empty since there is no self-loop in fsa2
        assert torch.allclose(out_fsa.frame, expected_seqframe)

        assert out_fsa.shape == (2, None, None), 'There should be two FSAs!'

        scores = out_fsa.get_tot_scores(log_semiring=False,
                                        use_double_scores=False)
        scores.sum().backward()
def test_case2(self):
    """Three frames, labels `c,c`: the k2 score must match torch's
    ctc_loss, both in value and in gradient."""
    log_softmax = torch.nn.functional.log_softmax

    for device in self.devices:
        # (T, N, C)
        torch_activation = torch.arange(1, 16).reshape(1, 3, 5).permute(
            1, 0, 2).to(device)
        torch_activation = torch_activation.to(torch.float32)
        torch_activation.requires_grad_(True)

        k2_activation = torch_activation.detach().clone().requires_grad_(
            True)

        # (T, N, C)
        torch_log_probs = log_softmax(torch_activation, dim=-1)

        # we have only one sequence and its labels are `c,c`
        targets = torch.tensor([3, 3]).to(device)
        input_lengths = torch.tensor([3]).to(device)
        target_lengths = torch.tensor([2]).to(device)

        torch_loss = torch.nn.functional.ctc_loss(
            log_probs=torch_log_probs,
            targets=targets,
            input_lengths=input_lengths,
            target_lengths=target_lengths,
            reduction='none')

        # (T, N, C) -> (N, T, C)
        act = k2_activation.permute(1, 0, 2)
        k2_log_probs = log_softmax(act, dim=-1)

        supervision_segments = torch.tensor([[0, 0, 3]], dtype=torch.int32)
        dense_fsa_vec = k2.DenseFsaVec(k2_log_probs,
                                       supervision_segments).to(device)

        ctc_topo_inv = k2.arc_sort(
            build_ctc_topo([0, 1, 2, 3, 4]).invert_())
        linear_fsa = k2.linear_fsa([3, 3])
        decoding_graph = k2.intersect(ctc_topo_inv, linear_fsa)
        decoding_graph = k2.connect(decoding_graph).invert_().to(device)

        target_graph = k2.intersect_dense(decoding_graph, dense_fsa_vec,
                                          100.0)
        k2_scores = target_graph.get_tot_scores(log_semiring=True,
                                                use_double_scores=False)
        assert torch.allclose(torch_loss, -1 * k2_scores)

        reference = torch.tensor([7.355742931366]).to(device)
        assert torch.allclose(torch_loss, reference)

        torch_loss.backward()
        (-k2_scores).backward()
        assert torch.allclose(torch_activation.grad, k2_activation.grad)
def test_simple(self):
    """On every device in self.devices, intersect one FSA with a single
    two-frame sequence and check frame indexes, scores and gradients."""
    s = '''
        0 1 1 1.0
        1 1 1 50.0
        1 2 2 2.0
        2 3 -1 3.0
        3
    '''
    for device in self.devices:
        fsa = k2.Fsa.from_str(s).to(device)
        fsa.requires_grad_(True)
        fsa_vec = k2.create_fsa_vec([fsa])

        frames = [[0.1, 0.2, 0.3], [0.04, 0.05, 0.06]]
        log_prob = torch.tensor([frames],
                                dtype=torch.float32,
                                device=device,
                                requires_grad=True)

        supervision_segments = torch.tensor([[0, 0, 2]], dtype=torch.int32)
        dense_fsa_vec = k2.DenseFsaVec(log_prob, supervision_segments)

        out_fsa = k2.intersect_dense(fsa_vec,
                                     dense_fsa_vec,
                                     output_beam=100000,
                                     seqframe_idx_name='seqframe',
                                     frame_idx_name='frame')

        # With a single sequence both index attributes are expected to
        # hold the same values.
        expected_frames = torch.tensor([0, 1, 2], device=device)
        assert torch.all(torch.eq(out_fsa.seqframe, expected_frames))
        assert torch.all(torch.eq(out_fsa.frame, expected_frames))

        scores = out_fsa.get_tot_scores(log_semiring=False,
                                        use_double_scores=False)
        scores.sum().backward()

        # `expected` results are computed using gtn.
        # See https://colab.research.google.com/drive/1FzEFjj5GoCDN2d05D9jE682CkR7QIlnm?usp=sharing
        expected_scores_out_fsa = torch.tensor([1.2, 2.06, 3.0],
                                               device=device)
        expected_grad_fsa = torch.tensor([1.0, 0.0, 1.0, 1.0],
                                         device=device)
        expected_grad_log_prob = torch.tensor(
            [0.0, 1.0, 0.0, 0.0, 0.0, 1.0],
            device=device).reshape_as(log_prob)

        assert torch.allclose(out_fsa.scores, expected_scores_out_fsa)
        assert torch.allclose(expected_grad_fsa, fsa.scores.grad)
        assert torch.allclose(expected_grad_log_prob, log_prob.grad)
def forward(
    self,
    nnet_output: torch.Tensor,
    texts: List,
    supervision_segments: torch.Tensor
) -> Tuple[torch.Tensor, int, int]:
    """
    Compute the numerator-only objective for a batch.

    Args:
      nnet_output: Network output, wrapped into a k2.DenseFsaVec together
        with `supervision_segments`.
      texts: Transcripts passed to ``self.graph_compiler.compile``.
      supervision_segments: Supervision segments; column 2 is passed to
        get_tot_objf_and_num_frames as the frame counts.

    Returns:
      The three values produced by get_tot_objf_and_num_frames.
    """
    # The graphs are moved to wherever the network output lives.
    graphs = self.graph_compiler.compile(texts).to(nnet_output.device)
    acoustics = k2.DenseFsaVec(nnet_output, supervision_segments)

    lattices = k2.intersect_dense(graphs, acoustics, 10.0)
    per_seq_scores = lattices.get_tot_scores(log_semiring=True,
                                             use_double_scores=True)

    objf, kept_frames, total_frames = get_tot_objf_and_num_frames(
        per_seq_scores, supervision_segments[:, 2])
    return objf, kept_frames, total_frames
def test_two_fsas_long(self):
    """As test_two_fsas, but with a long DenseFsaVec for easier
    profiling; only checks that intersection and backward run."""
    s1 = '''
        0 1 1 1.0
        1 1 1 50.0
        1 2 2 2.0
        2 3 -1 3.0
        3
    '''

    s2 = '''
        0 1 1 1.0
        1 2 2 2.0
        2 3 -1 3.0
        3
    '''

    devices = [torch.device('cpu')]
    if torch.cuda.is_available():
        devices += [torch.device('cuda', 0)]

    for device in devices:
        fsa1, fsa2 = [k2.Fsa.from_str(text) for text in (s1, s2)]
        for fsa in (fsa1, fsa2):
            fsa.requires_grad_(True)

        fsa_vec = k2.create_fsa_vec([fsa1, fsa2])
        log_prob = torch.rand((2, 500, 3),
                              dtype=torch.float32,
                              device=device,
                              requires_grad=True)

        supervision_segments = torch.tensor([[0, 0, 490], [1, 0, 300]],
                                            dtype=torch.int32)
        dense_fsa_vec = k2.DenseFsaVec(log_prob, supervision_segments)

        fsa_vec = fsa_vec.to(device)
        out_fsa = k2.intersect_dense(fsa_vec,
                                     dense_fsa_vec,
                                     output_beam=100000)
        assert out_fsa.shape == (2, None, None), 'There should be two FSAs!'

        scores = k2.get_tot_scores(out_fsa,
                                   log_semiring=False,
                                   use_float_scores=True)
        scores.sum().backward()
def align(self, cuts: Union[AnyCut, CutSet]) -> torch.Tensor:
    """
    Perform forced alignment and return a tensor that represents a batch
    of frame-level alignments, e.g.::

        torch.tensor([
            [0, 0, 0, 1, 57, 57, 35, 35, 35, ...],
            [...],
            ...
        ])

    :param cuts: a single cut or a ``CutSet``; a single cut is wrapped
        into a one-element ``CutSet`` first.
    :return: an int32 tensor with shape ``(batch_size, num_frames)``.
    """
    if isinstance(cuts, (Cut, MixedCut)):
        cuts = CutSet.from_cuts([cuts])

    # Only the first cut's sampling rate is checked.
    assert cuts[
        0].sampling_rate == self.sampling_rate, f'{cuts[0].sampling_rate} != {self.sampling_rate}'

    cuts = cuts.map_supervisions(self.normalize_text)

    # Extract feats
    # (batch, seq_len, num_feats)
    feature_extractor = OnTheFlyFeatures(self.extractor)
    feats, _ = feature_extractor(cuts)
    feats = feats.permute(0, 2, 1)

    # One transcript per cut: the texts of its supervisions joined by spaces.
    texts = [
        ' '.join(sup.text for sup in cut.supervisions) for cut in cuts
    ]

    # Compute AM posteriors
    # (batch, seq_len ~/ 4, num_phones)
    posteriors, _, _ = self.model(feats)

    # Note: we are using "dummy" supervisions so that the aligner also
    # considers the padding area. We can adjust that behaviour if needed by
    # passing actual supervision segments, but then we will have a ragged
    # tensor (will need to pad the alignments themselves).
    sups = self.dummy_supervisions(feats)
    posteriors_fsa = k2.DenseFsaVec(posteriors.permute(0, 2, 1), sups)

    # Intersection with ground truth transcript graphs; only the numerator
    # graphs are needed here.
    num_graphs, _unused_den = self.compiler.compile(texts, self.P)
    lattice = k2.intersect_dense(num_graphs, posteriors_fsa,
                                 output_beam=10.0)
    best_path = k2.shortest_path(lattice, use_double_scores=True)

    # Retrieve sequences of phone IDs per frame, dropping the last label
    # of every path.
    # (batch, seq_len ~/ 4) -- dtype int32 (num phone labels)
    per_utterance = []
    for i in range(best_path.shape[0]):
        per_utterance.append(best_path[i].labels[:-1])
    frame_labels = torch.stack(per_utterance)
    return frame_labels
def _compute_mmi_loss_pruned(
        nnet_output: torch.Tensor,
        texts: List[str],
        supervision_segments: torch.Tensor,
        graph_compiler: MmiTrainingGraphCompiler,
        P: k2.Fsa,
        den_scale: float = 1.0
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    '''
    Compute the MMI objective with a pruned denominator intersection.

    See :func:`_compute_mmi_loss_exact_optimized` for the meaning of the
    arguments.

    `pruned` means the denominator lattice is built with
    k2.intersect_dense_pruned; the numerator still uses
    k2.intersect_dense.

    Note:
      It uses the least amount of memory, but the loss is not exact due
      to pruning.

    Returns:
      The three values produced by get_tot_objf_and_num_frames for the
      per-sequence scores `num - den_scale * den`.
    '''
    numerator_graphs, denominator_graphs = graph_compiler.compile(
        texts, P, replicate_den=False)
    acoustics = k2.DenseFsaVec(nnet_output, supervision_segments)

    numerator_lats = k2.intersect_dense(numerator_graphs,
                                        acoustics,
                                        output_beam=10.0)

    # the values for search_beam/output_beam/min_active_states/
    # max_active_states are not tuned. You may want to tune them.
    pruning_opts = dict(search_beam=20.0,
                        output_beam=7.0,
                        min_active_states=30,
                        max_active_states=10000)
    denominator_lats = k2.intersect_dense_pruned(denominator_graphs,
                                                 acoustics,
                                                 **pruning_opts)

    score_opts = dict(log_semiring=True, use_double_scores=True)
    numerator_scores = numerator_lats.get_tot_scores(**score_opts)
    denominator_scores = denominator_lats.get_tot_scores(**score_opts)

    mmi_scores = numerator_scores - den_scale * denominator_scores

    objf, kept_frames, total_frames = get_tot_objf_and_num_frames(
        mmi_scores, supervision_segments[:, 2])
    return objf, kept_frames, total_frames
def forward(
    self,
    log_probs: torch.Tensor,
    targets: torch.Tensor,
    input_lengths: torch.Tensor,
    target_lengths: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Compute per-sequence CTC losses through k2 graph intersection.

    Args:
      log_probs: Network output; normalized by GradExpNormalize below.
      targets: Target label sequences. They are reordered with
        ``targets[order]``, so the first dimension is presumably the batch
        dimension -- confirm against the caller.
      input_lengths: Input length per sequence, used to create the
        supervisions and for the normalization.
      target_lengths: Target length per sequence.

    Returns:
      A pair ``(losses, valid_mask)``: ``valid_mask`` is the mask returned
      by get_tot_objf_and_finite_mask and ``losses`` holds the negated
      total scores selected by that mask.
    """
    if self.blank != 0:
        # rearrange log_probs to put blank at the first place
        # and shift targets to emulate blank = 0
        log_probs, targets = make_blank_first(self.blank, log_probs,
                                              targets)

    # `order` is the permutation that comes with the supervisions; the
    # targets and their lengths are reordered to match it.
    supervisions, order = create_supervision(input_lengths)
    order = order.long()
    targets = targets[order]
    target_lengths = target_lengths[order]

    # PyTorch is doing the log-softmax normalization as part of the CTC
    # computation.
    # More: https://github.com/k2-fsa/k2/issues/575
    log_probs = GradExpNormalize.apply(
        log_probs, input_lengths,
        "mean" if self.reduction != "sum" else "none")

    # Keep the graph compiler on the same device as the network output.
    if log_probs.device != self.graph_compiler.device:
        self.graph_compiler.to(log_probs.device)

    # With a padded DenseFsaVec the labels are shifted by one.
    num_graphs = self.graph_compiler.compile(
        targets + 1 if self.pad_fsavec else targets, target_lengths)

    dense_fsa_vec = (prep_padded_densefsavec(log_probs, supervisions)
                     if self.pad_fsavec else k2.DenseFsaVec(
                         log_probs, supervisions))

    # The largest finite float32 is used as the output beam.
    num_lats = k2.intersect_dense(num_graphs, dense_fsa_vec,
                                  torch.finfo(torch.float32).max)

    # use_double_scores=True does matter
    # since otherwise it sometimes makes rounding errors
    num_tot_scores = num_lats.get_tot_scores(log_semiring=True,
                                             use_double_scores=True)
    tot_scores, valid_mask = get_tot_objf_and_finite_mask(
        num_tot_scores, self.reduction)
    return -tot_scores[valid_mask], valid_mask
def test_two_dense(self):
    """Intersect two copies of one FSA with a two-sequence DenseFsaVec,
    with and without an explicit a_to_b_map."""
    s = '''
        0 1 1 1.0
        1 1 1 50.0
        1 2 2 2.0
        2 3 -1 3.0
        3
    '''

    for use_map in (True, False):
        fsa = k2.Fsa.from_str(s)
        fsa.requires_grad_(True)
        fsa_vec = k2.create_fsa_vec([fsa, fsa])

        first_seq = [[0.1, 0.2, 0.3], [0.04, 0.05, 0.06], [0.0, 0.0, 0.0]]
        second_seq = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.0, 0.0, 0.0]]
        log_prob = torch.tensor([first_seq, second_seq],
                                dtype=torch.float32,
                                requires_grad=True)

        # With the map, both FSAs are paired with sequence 0.
        a_to_b_map = (torch.tensor([0, 0], dtype=torch.int32)
                      if use_map else None)

        supervision_segments = torch.tensor([[0, 0, 3], [1, 0, 2]],
                                            dtype=torch.int32)
        dense_fsa_vec = k2.DenseFsaVec(log_prob, supervision_segments)
        out_fsa = k2.intersect_dense(fsa_vec,
                                     dense_fsa_vec,
                                     output_beam=100000,
                                     a_to_b_map=a_to_b_map,
                                     seqframe_idx_name='seqframe',
                                     frame_idx_name='frame')

        if use_map:
            expected_seqframe = [0, 1, 2, 3, 0, 1, 2, 3]
            expected_frame = [0, 1, 2, 3, 0, 1, 2, 3]
        else:
            expected_seqframe = [0, 1, 2, 3, 4, 5, 6]
            expected_frame = [0, 1, 2, 3, 0, 1, 2]

        assert torch.allclose(
            out_fsa.seqframe,
            torch.tensor(expected_seqframe, dtype=torch.int32))
        assert torch.allclose(
            out_fsa.frame, torch.tensor(expected_frame, dtype=torch.int32))

        assert out_fsa.shape == (2, None, None), 'There should be two FSAs!'

        scores = out_fsa.get_tot_scores(log_semiring=False,
                                        use_double_scores=False)
        scores.sum().backward()

        # `expected` results are computed using gtn.
        # See https://bit.ly/3oYObeb
        #
        # Not checked here; the reference values were
        #   out_fsa.scores: [1.2, 2.06, 3.0, 1.2, 50.5, 2.0, 3.0]
        #   log_prob.grad:  [0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0, 0, 0,
        #                    0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0]
        #                   (reshaped like log_prob)
        if use_map:
            expected_grad_fsa = torch.tensor([2.0, 2.0, 2.0, 2.0])
        else:
            expected_grad_fsa = torch.tensor([2.0, 1.0, 2.0, 2.0])

        assert torch.allclose(expected_grad_fsa, fsa.scores.grad)
def _compute_mmi_loss_exact_optimized(
        nnet_output: torch.Tensor,
        texts: List[str],
        supervision_segments: torch.Tensor,
        graph_compiler: MmiTrainingGraphCompiler,
        P: k2.Fsa,
        den_scale: float = 1.0
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    '''
    Compute the MMI objective with a single unpruned intersection.

    `exact` in the function name means it uses a version of intersection
    without pruning.

    `optimized` in the function name means it calls k2.intersect_dense
    only once, on the numerator and denominator graphs together.

    Note:
      It is faster at the cost of using more memory.

    Args:
      nnet_output:
        A 3-D tensor of shape [N, T, C]
      texts:
        The transcript. Each element consists of space(s) separated words.
      supervision_segments:
        A 2-D tensor that will be passed to :func:`k2.DenseFsaVec`.
      graph_compiler:
        Used to build num_graphs and den_graphs
      P:
        Represents a bigram Fsa.
      den_scale:
        The scale applied to the denominator tot_scores.

    Returns:
      The three values produced by get_tot_objf_and_num_frames for the
      per-sequence scores `num - den_scale * den`.
    '''
    num_graphs, den_graphs = graph_compiler.compile(texts,
                                                    P,
                                                    replicate_den=False)
    dense_fsa_vec = k2.DenseFsaVec(nnet_output, supervision_segments)

    device = num_graphs.device
    num_fsas = num_graphs.shape[0]
    assert dense_fsa_vec.dim0() == num_fsas
    assert den_graphs.shape[0] == 1

    # The aux_labels of num_graphs is k2.RaggedInt but it is torch.Tensor
    # for den_graphs. Convert den_graphs.aux_labels to k2.RaggedInt so that
    # the two can be concatenated with k2.cat below.
    den_graphs.convert_attr_to_ragged_(name='aux_labels')

    # Concatenating num_graphs and den_graphs reduces the number of calls
    # to k2.intersect_dense.
    combined = k2.cat([num_graphs, den_graphs])

    # NOTE: The a_to_b_map in k2.intersect_dense must be sorted
    # so the following reorders the concatenated graphs.
    seq_ids = torch.arange(num_fsas, dtype=torch.int32)  # [0, 1, 2, ...]
    # The single denominator graph sits at position num_fsas of `combined`:
    # [num_fsas, num_fsas, num_fsas, ...]
    den_ids = torch.full((num_fsas,), num_fsas, dtype=torch.int32)
    # [0, num_fsas, 1, num_fsas, 2, num_fsas, ...]
    interleaved_ids = torch.stack([seq_ids, den_ids],
                                  dim=1).reshape(-1).to(device)
    interleaved_graphs = k2.index(combined, interleaved_ids)

    # [0, 0, 1, 1, 2, 2, ...]
    a_to_b_map = torch.stack([seq_ids, seq_ids],
                             dim=1).reshape(-1).to(device)

    num_den_lats = k2.intersect_dense(interleaved_graphs,
                                      dense_fsa_vec,
                                      output_beam=10.0,
                                      a_to_b_map=a_to_b_map)

    all_tot_scores = num_den_lats.get_tot_scores(log_semiring=True,
                                                 use_double_scores=True)
    # Even positions are numerators, odd positions are denominators.
    numerator_scores = all_tot_scores[::2]
    denominator_scores = all_tot_scores[1::2]

    mmi_scores = numerator_scores - den_scale * denominator_scores

    objf, kept_frames, total_frames = get_tot_objf_and_num_frames(
        mmi_scores, supervision_segments[:, 2])
    return objf, kept_frames, total_frames
def get_objf(batch: Dict,
             model: AcousticModel,
             device: torch.device,
             graph_compiler: CtcTrainingGraphCompiler,
             training: bool,
             optimizer: Optional[torch.optim.Optimizer] = None):
    """
    Compute the CTC objective for one batch and, when training, take one
    optimizer step.

    Args:
      batch: Dict with a 'features' tensor and a 'supervisions' dict
        holding 'sequence_idx', 'start_frame', 'num_frames' and 'text'.
      model: The acoustic model; its `subsampling_factor` is used to scale
        the supervision frame indexes.
      device: Device the features and the decoding graph are moved to.
      graph_compiler: Builds the decoding graph from the transcripts.
      training: If True, back-propagate and update the model; otherwise
        the model runs under torch.no_grad().
      optimizer: Used (and therefore required) when `training` is True.

    Returns:
      A tuple of Python numbers: the negated total score and the two frame
      counts returned by get_tot_objf_and_num_frames.
    """
    feature = batch['features']
    supervisions = batch['supervisions']

    factor = model.subsampling_factor
    segment_columns = (
        supervisions['sequence_idx'],
        torch.floor_divide(supervisions['start_frame'], factor),
        torch.floor_divide(supervisions['num_frames'], factor),
    )
    supervision_segments = torch.stack(segment_columns, 1).to(torch.int32)

    # Sort the segments by decreasing duration and reorder the transcripts
    # in the same way.
    indices = torch.argsort(supervision_segments[:, 2], descending=True)
    supervision_segments = supervision_segments[indices]
    all_texts = supervisions['text']
    texts = [all_texts[idx] for idx in indices.tolist()]

    assert feature.ndim == 3
    # at entry, feature is [N, T, C]; afterwards it is [N, C, T]
    feature = feature.to(device).permute(0, 2, 1)

    if not training:
        with torch.no_grad():
            nnet_output = model(feature)
    else:
        nnet_output = model(feature)

    # nnet_output is [N, C, T]; make it [N, T, C]
    nnet_output = nnet_output.permute(0, 2, 1)

    decoding_graph = graph_compiler.compile(texts).to(device)
    dense_fsa_vec = k2.DenseFsaVec(nnet_output, supervision_segments)

    assert decoding_graph.is_cuda()
    assert decoding_graph.device == device
    assert nnet_output.device == device

    # TODO(haowen): with a small `beam`, we may get empty `target_graph`,
    # thus `tot_scores` will be `inf`. Definitely we need to handle this
    # later.
    target_graph = k2.intersect_dense(decoding_graph, dense_fsa_vec, 10.0)
    tot_scores = k2.get_tot_scores(target_graph,
                                   log_semiring=True,
                                   use_double_scores=True)

    tot_score, tot_frames, all_frames = get_tot_objf_and_num_frames(
        tot_scores, supervision_segments[:, 2])

    if training:
        optimizer.zero_grad()
        (-tot_score).backward()
        clip_grad_value_(model.parameters(), 5.0)
        optimizer.step()

    neg_objf = -(tot_score.detach().cpu().item())
    return neg_objf, tot_frames.cpu().item(), all_frames.cpu().item()
def get_objf(batch: Dict,
             model: AcousticModel,
             device: torch.device,
             graph_compiler: CtcTrainingGraphCompiler,
             is_training: bool,
             is_update: bool,
             accum_grad: int = 1,
             att_rate: float = 0.0,
             optimizer: Optional[torch.optim.Optimizer] = None):
    """Compute the CTC objective (optionally mixed with an attention loss).

    Gradients are accumulated on every training call; the optimizer is only
    stepped (and the gradients cleared) when ``is_update`` is True.

    Args:
      batch: Dict providing ``'features'`` (a 3-D tensor) and
        ``'supervisions'``.
      model: Called as ``model(feature, supervision_segments)`` and expected
        to return ``(nnet_output, encoder_memory, memory_mask)``.
      device: Device on which the model and the graphs are evaluated.
      graph_compiler: Compiles the transcripts into a decoding graph; it is
        also handed to ``model.decoder_forward``.
      is_training: If True, run with gradients and back-propagate the loss.
      is_update: If True (and training), clip gradients, step and zero them.
      accum_grad: The loss is divided by ``len(texts) * accum_grad``.
      att_rate: Weight of the attention loss; 0.0 skips the decoder.
      optimizer: Used only when training and ``is_update`` is True.

    Returns:
      A tuple of Python numbers ``(-tot_score, tot_frames, all_frames)``
      taken from ``get_tot_objf_and_num_frames``.
    """
    feature = batch['features']
    supervisions = batch['supervisions']

    def _subsample(frames):
        # Frame-index mapping used for this model: two successive
        # "(x - 1) // 2" reductions.
        return ((frames - 1) // 2 - 1) // 2

    supervision_segments = torch.stack(
        (supervisions['sequence_idx'],
         _subsample(supervisions['start_frame']),
         _subsample(supervisions['num_frames'])), 1).to(torch.int32)
    # The mapping above can go negative for very small values.
    supervision_segments = torch.clamp(supervision_segments, min=0)

    # Order the segments (and their transcripts) by decreasing duration.
    order = torch.argsort(supervision_segments[:, 2], descending=True)
    supervision_segments = supervision_segments[order]
    all_texts = supervisions['text']
    texts = [all_texts[idx] for idx in order]

    assert feature.ndim == 3
    # At entry the feature is [N, T, C]; the model takes [N, C, T].
    feature = feature.to(device).permute(0, 2, 1)

    def _run_model():
        # Encoder forward pass, plus the decoder loss when it is weighted in.
        output, encoder_memory, memory_mask = model(feature,
                                                    supervision_segments)
        decoder_loss = None
        if att_rate != 0.0:
            decoder_loss = model.decoder_forward(encoder_memory, memory_mask,
                                                 supervisions, graph_compiler)
        return output, decoder_loss

    if is_training:
        nnet_output, att_loss = _run_model()
    else:
        with torch.no_grad():
            nnet_output, att_loss = _run_model()

    # The model emits [N, C, T]; switch back to [N, T, C].
    nnet_output = nnet_output.permute(0, 2, 1)

    decoding_graph = graph_compiler.compile(texts).to(device)
    dense_fsa_vec = k2.DenseFsaVec(nnet_output, supervision_segments)

    assert decoding_graph.is_cuda()
    assert decoding_graph.device == device
    assert nnet_output.device == device

    target_graph = k2.intersect_dense(decoding_graph, dense_fsa_vec, 10.0)
    tot_scores = target_graph.get_tot_scores(log_semiring=True,
                                             use_double_scores=True)
    tot_score, tot_frames, all_frames = get_tot_objf_and_num_frames(
        tot_scores, supervision_segments[:, 2])

    if is_training:
        normalizer = len(texts) * accum_grad
        if att_rate != 0.0:
            loss = (-(1.0 - att_rate) * tot_score +
                    att_rate * att_loss) / normalizer
        else:
            loss = (-tot_score) / normalizer
        loss.backward()
        if is_update:
            clip_grad_value_(model.parameters(), 5.0)
            optimizer.step()
            optimizer.zero_grad()

    return (
        -tot_score.detach().cpu().item(),
        tot_frames.cpu().item(),
        all_frames.cpu().item(),
    )
def test_random_case2(self):
    """Compare k2 CTC scores/gradients with torch's ctc_loss on 2 random
    sequences."""
    for device in self.devices:
        # Random sizes; the longer sequence is always placed first.
        T1 = torch.randint(10, 200, (1, )).item()
        T2 = torch.randint(9, 100, (1, )).item()
        C = torch.randint(20, 30, (1, )).item()
        if T1 < T2:
            T1, T2 = T2, T1

        torch_act_1 = torch.rand((T1, C), dtype=torch.float32,
                                 device=device).requires_grad_(True)
        torch_act_2 = torch.rand((T2, C), dtype=torch.float32,
                                 device=device).requires_grad_(True)

        # Independent leaf copies so each implementation gets its own grads.
        k2_act_1 = torch_act_1.detach().clone().requires_grad_(True)
        k2_act_2 = torch_act_2.detach().clone().requires_grad_(True)

        pad = torch.nn.utils.rnn.pad_sequence
        # [T, N, C] for torch
        torch_activations = pad([torch_act_1, torch_act_2],
                                batch_first=False,
                                padding_value=0)
        # [N, T, C] for k2
        k2_activations = pad([k2_act_1, k2_act_2],
                             batch_first=True,
                             padding_value=0)

        target_length1 = torch.randint(1, T1, (1, )).item()
        target_length2 = torch.randint(1, T2, (1, )).item()
        target_lengths = torch.tensor([target_length1,
                                       target_length2]).to(device)
        targets = torch.randint(1, C - 1,
                                (target_lengths.sum(), )).to(device)

        # Reference loss from PyTorch; log-probs are [T, N, C].
        torch_log_probs = torch.nn.functional.log_softmax(torch_activations,
                                                          dim=-1)
        input_lengths = torch.tensor([T1, T2]).to(device)
        torch_loss = torch.nn.functional.ctc_loss(
            log_probs=torch_log_probs,
            targets=targets,
            input_lengths=input_lengths,
            target_lengths=target_lengths,
            reduction='none')

        assert T1 >= T2
        supervision_segments = torch.tensor([[0, 0, T1], [1, 0, T2]],
                                            dtype=torch.int32)
        k2_log_probs = torch.nn.functional.log_softmax(k2_activations,
                                                       dim=-1)
        dense_fsa_vec = k2.DenseFsaVec(k2_log_probs,
                                       supervision_segments).to(device)

        # Build the per-sequence CTC graphs from the flat `targets` tensor.
        ctc_topo_inv = k2.arc_sort(build_ctc_topo(list(range(C))).invert_())
        first_transcript = targets[:target_length1].tolist()
        second_transcript = targets[target_length1:].tolist()
        linear_fsa = k2.linear_fsa([first_transcript, second_transcript])
        decoding_graph = k2.intersect(ctc_topo_inv, linear_fsa)
        decoding_graph = k2.connect(decoding_graph).invert_().to(device)

        target_graph = k2.intersect_dense(decoding_graph, dense_fsa_vec,
                                          100.0)
        k2_scores = target_graph.get_tot_scores(log_semiring=True,
                                                use_double_scores=False)
        assert torch.allclose(torch_loss, -1 * k2_scores)

        # Back-propagate the same random weighting through both losses.
        scale = torch.rand_like(torch_loss) * 100
        (torch_loss * scale).sum().backward()
        (-k2_scores * scale).sum().backward()

        for reference, candidate in ((torch_act_1, k2_act_1),
                                     (torch_act_2, k2_act_2)):
            assert torch.allclose(reference.grad, candidate.grad, atol=1e-2)
def get_loss(batch: Dict,
             model: AcousticModel,
             P: k2.Fsa,
             device: torch.device,
             graph_compiler: MmiMbrTrainingGraphCompiler,
             is_training: bool,
             optimizer: Optional[torch.optim.Optimizer] = None):
    """Compute the combined MMI + MBR loss for one batch; optionally train.

    Args:
      batch: Dict providing ``'features'`` (a 3-D tensor) and
        ``'supervisions'`` (with ``'sequence_idx'``, ``'start_frame'``,
        ``'num_frames'`` and ``'text'`` entries).
      model: Acoustic model; its ``subsampling_factor`` is used to convert
        frame indexes into network-output frame indexes.
      P: FSA passed to ``graph_compiler.compile``; must live on ``device``.
      device: Device on which the model and the graphs are evaluated.
      graph_compiler: Returns ``(num_graph, den_graph, decoding_graph)``.
      is_training: If True, run with gradients and update the model.
      optimizer: Used only when ``is_training`` is True.

    Returns:
      A tuple of Python numbers
      ``(mmi_loss, mbr_loss, tot_frames, all_frames)``.

    Note:
      ``den_scale`` is not a parameter; it is read from the enclosing
      (module) scope.
    """
    assert P.device == device
    feature = batch['features']
    supervisions = batch['supervisions']

    # Columns: sequence index, start frame, number of frames (the last two
    # expressed in subsampled frames).
    supervision_segments = torch.stack(
        (supervisions['sequence_idx'],
         torch.floor_divide(supervisions['start_frame'],
                            model.subsampling_factor),
         torch.floor_divide(supervisions['num_frames'],
                            model.subsampling_factor)), 1).to(torch.int32)
    # Order the segments (and their transcripts) by decreasing duration.
    indices = torch.argsort(supervision_segments[:, 2], descending=True)
    supervision_segments = supervision_segments[indices]
    texts = supervisions['text']
    texts = [texts[idx] for idx in indices]

    assert feature.ndim == 3
    feature = feature.to(device)
    # at entry, feature is [N, T, C]
    feature = feature.permute(0, 2, 1)  # now feature is [N, C, T]

    if is_training:
        nnet_output = model(feature)
    else:
        with torch.no_grad():
            nnet_output = model(feature)

    # nnet_output is [N, C, T]
    nnet_output = nnet_output.permute(0, 2, 1)  # now nnet_output is [N, T, C]

    if is_training:
        num_graph, den_graph, decoding_graph = graph_compiler.compile(
            texts, P)
    else:
        with torch.no_grad():
            num_graph, den_graph, decoding_graph = graph_compiler.compile(
                texts, P)

    assert num_graph.requires_grad == is_training
    assert den_graph.requires_grad is False
    assert decoding_graph.requires_grad is False
    assert len(
        decoding_graph.shape) == 2 or decoding_graph.shape == (1, None, None)

    num_graph = num_graph.to(device)
    den_graph = den_graph.to(device)
    decoding_graph = decoding_graph.to(device)

    dense_fsa_vec = k2.DenseFsaVec(nnet_output, supervision_segments)
    assert nnet_output.device == device

    num_lats = k2.intersect_dense(num_graph,
                                  dense_fsa_vec,
                                  10.0,
                                  seqframe_idx_name='seqframe_idx')

    mbr_lats = k2.intersect_dense_pruned(decoding_graph,
                                         dense_fsa_vec,
                                         20.0,
                                         7.0,
                                         30,
                                         10000,
                                         seqframe_idx_name='seqframe_idx')

    # WARNING: reusing the MBR lattices as the denominator is not working at
    # present (the total loss is not stable), so it stays disabled.
    reuse_mbr_lats_as_den = False
    if reuse_mbr_lats_as_den:
        # in this case, we can remove den_graph
        den_lats = mbr_lats
    else:
        den_lats = k2.intersect_dense(den_graph, dense_fsa_vec, 10.0)

    num_tot_scores = num_lats.get_tot_scores(log_semiring=True,
                                             use_double_scores=True)
    den_tot_scores = den_lats.get_tot_scores(log_semiring=True,
                                             use_double_scores=True)

    if den_lats is mbr_lats:
        # Some entries in den_tot_scores may be -inf.
        # The corresponding sequences are discarded/ignored.
        finite_indexes = torch.isfinite(den_tot_scores)
        den_tot_scores = den_tot_scores[finite_indexes]
        num_tot_scores = num_tot_scores[finite_indexes]
    else:
        finite_indexes = None

    tot_scores = num_tot_scores - den_scale * den_tot_scores

    (tot_score, tot_frames,
     all_frames) = get_tot_objf_and_num_frames(tot_scores,
                                               supervision_segments[:, 2],
                                               finite_indexes)

    # Sparse matrices indexed by (sequence-frame, phone) holding the arc
    # posteriors of the numerator and of the MBR lattices.
    num_rows = dense_fsa_vec.scores.shape[0]
    num_cols = dense_fsa_vec.scores.shape[1] - 1
    mbr_num_sparse = k2.create_sparse(rows=num_lats.seqframe_idx,
                                      cols=num_lats.phones,
                                      values=num_lats.get_arc_post(True,
                                                                   True).exp(),
                                      size=(num_rows, num_cols),
                                      min_col_index=0)

    mbr_den_sparse = k2.create_sparse(rows=mbr_lats.seqframe_idx,
                                      cols=mbr_lats.phones,
                                      values=mbr_lats.get_arc_post(True,
                                                                   True).exp(),
                                      size=(num_rows, num_cols),
                                      min_col_index=0)
    # NOTE: Due to limited support of PyTorch's autograd for sparse tensors,
    # we cannot use (mbr_num_sparse - mbr_den_sparse) here
    #
    # The following works only for torch >= 1.7.0
    mbr_loss = torch.sparse.sum(
        k2.sparse.abs((mbr_num_sparse + (-mbr_den_sparse)).coalesce()))

    mmi_loss = -tot_score

    total_loss = mmi_loss + mbr_loss

    if is_training:
        optimizer.zero_grad()
        total_loss.backward()
        clip_grad_value_(model.parameters(), 5.0)
        optimizer.step()

    ans = (
        mmi_loss.detach().cpu().item(),
        mbr_loss.detach().cpu().item(),
        tot_frames.cpu().item(),
        all_frames.cpu().item(),
    )
    return ans
def forward(
        self, nnet_output: torch.Tensor, texts: List,
        supervision_segments: torch.Tensor
) -> Tuple[torch.Tensor, int, int]:
    """Compute the MMI objective with a single call to k2.intersect_dense.

    The numerator graphs and the (single, shared) denominator graph are
    interleaved into one FsaVec so that both are intersected with the
    network output in one pass.

    Args:
      nnet_output: Network output handed to ``k2.DenseFsaVec``.
      texts: Transcripts handed to ``self.graph_compiler.compile``.
      supervision_segments: Segment table handed to ``k2.DenseFsaVec``; its
        third column is forwarded to ``get_tot_objf_and_num_frames``.

    Returns:
      ``(tot_score, tot_frames, all_frames)`` exactly as produced by
      ``get_tot_objf_and_num_frames``.
    """
    num_graphs, den_graphs = self.graph_compiler.compile(
        texts, self.P, replicate_den=False)
    dense_fsa_vec = k2.DenseFsaVec(nnet_output, supervision_segments)

    device = num_graphs.device
    num_fsas = num_graphs.shape[0]
    assert dense_fsa_vec.dim0() == num_fsas
    assert den_graphs.shape[0] == 1

    # num_graphs.aux_labels is a k2.RaggedInt while den_graphs.aux_labels is
    # a torch.Tensor; convert the latter so the two can be concatenated with
    # k2.cat() below.
    den_graphs.convert_attr_to_ragged_(name='aux_labels')
    num_den_graphs = k2.cat([num_graphs, den_graphs])

    # NOTE: The a_to_b_map in k2.intersect_dense must be sorted, so the
    # graphs are reordered to alternate numerator / denominator.
    #
    # In `num_den_graphs` the numerators sit at 0 .. num_fsas - 1 and the
    # denominator at index num_fsas, giving the gather order
    # [0, num_fsas, 1, num_fsas, 2, num_fsas, ...].
    numerator_ids = torch.arange(num_fsas, dtype=torch.int32)
    denominator_ids = torch.full((num_fsas, ), num_fsas, dtype=torch.int32)
    interleaved_ids = torch.stack([numerator_ids, denominator_ids],
                                  dim=1).reshape(-1).to(device)
    num_den_reordered_graphs = k2.index(num_den_graphs, interleaved_ids)

    # Each (numerator, denominator) pair maps onto the same dense sequence:
    # [0, 0, 1, 1, 2, 2, ...].
    sequence_ids = torch.arange(num_fsas, dtype=torch.int32)
    a_to_b_map = torch.stack([sequence_ids, sequence_ids],
                             dim=1).reshape(-1).to(device)

    num_den_lats = k2.intersect_dense(num_den_reordered_graphs,
                                      dense_fsa_vec,
                                      output_beam=10.0,
                                      a_to_b_map=a_to_b_map)

    num_den_tot_scores = num_den_lats.get_tot_scores(log_semiring=True,
                                                     use_double_scores=True)

    # Even positions hold numerator scores, odd positions denominator scores.
    num_tot_scores = num_den_tot_scores[::2]
    den_tot_scores = num_den_tot_scores[1::2]

    tot_scores = num_tot_scores - self.den_scale * den_tot_scores
    return get_tot_objf_and_num_frames(tot_scores,
                                       supervision_segments[:, 2])
def test_case4(self):
    """Batch three fixed CTC cases together and compare k2 with torch."""
    for device in self.devices:
        # put case3, case2 and case1 into a batch
        torch_activation_1 = torch.tensor(
            [[0., 0., 0., 0., 0.]]).to(device).requires_grad_(True)

        torch_activation_2 = torch.arange(1, 16).reshape(3, 5).to(
            torch.float32).to(device).requires_grad_(True)

        torch_activation_3 = torch.tensor([
            [-5, -4, -3, -2, -1],
            [-10, -9, -8, -7, -6],
            [-15, -14, -13, -12, -11.],
        ]).to(device).requires_grad_(True)

        # Independent leaf copies so each implementation gets its own grads.
        k2_activation_1, k2_activation_2, k2_activation_3 = (
            activation.detach().clone().requires_grad_(True)
            for activation in (torch_activation_1, torch_activation_2,
                               torch_activation_3))

        pad = torch.nn.utils.rnn.pad_sequence
        # [T, N, C]
        torch_activations = pad(
            [torch_activation_3, torch_activation_2, torch_activation_1],
            batch_first=False,
            padding_value=0)
        # [N, T, C]
        k2_activations = pad(
            [k2_activation_3, k2_activation_2, k2_activation_1],
            batch_first=True,
            padding_value=0)

        # [[b,c], [c,c], [a]]
        targets = torch.tensor([2, 3, 3, 3, 1]).to(device)
        input_lengths = torch.tensor([3, 3, 1]).to(device)
        target_lengths = torch.tensor([2, 2, 1]).to(device)

        # (T, N, C)
        torch_log_probs = torch.nn.functional.log_softmax(torch_activations,
                                                          dim=-1)
        torch_loss = torch.nn.functional.ctc_loss(
            log_probs=torch_log_probs,
            targets=targets,
            input_lengths=input_lengths,
            target_lengths=target_lengths,
            reduction='none')

        expected_loss = torch.tensor(
            [4.938850402832, 7.355742931366, 1.6094379425049]).to(device)
        assert torch.allclose(torch_loss, expected_loss)

        k2_log_probs = torch.nn.functional.log_softmax(k2_activations,
                                                       dim=-1)
        supervision_segments = torch.tensor(
            [[0, 0, 3], [1, 0, 3], [2, 0, 1]], dtype=torch.int32)
        dense_fsa_vec = k2.DenseFsaVec(k2_log_probs,
                                       supervision_segments).to(device)

        ctc_topo_inv = k2.arc_sort(
            build_ctc_topo([0, 1, 2, 3, 4]).invert_())

        # [ [b, c], [c, c], [a]]
        linear_fsa = k2.linear_fsa([[2, 3], [3, 3], [1]])
        decoding_graph = k2.intersect(ctc_topo_inv, linear_fsa)
        decoding_graph = k2.connect(decoding_graph).invert_().to(device)

        target_graph = k2.intersect_dense(decoding_graph, dense_fsa_vec,
                                          100.0)
        k2_scores = target_graph.get_tot_scores(log_semiring=True,
                                                use_double_scores=False)
        assert torch.allclose(torch_loss, -1 * k2_scores)

        # Back-propagate the same weighting through both losses.
        scale = torch.tensor([1., -2, 3.5]).to(device)
        (torch_loss * scale).sum().backward()
        (-k2_scores * scale).sum().backward()

        gradient_pairs = (
            (torch_activation_1, k2_activation_1),
            (torch_activation_2, k2_activation_2),
            (torch_activation_3, k2_activation_3),
        )
        for reference, candidate in gradient_pairs:
            assert torch.allclose(reference.grad, candidate.grad)
def test_two_dense(self):
    """Intersect two copies of one FSA with a 2-sequence DenseFsaVec.

    Runs twice per device: without ``a_to_b_map`` (FSA i is paired with
    dense sequence i) and with ``a_to_b_map = [0, 0]`` (both FSAs are paired
    with dense sequence 0).
    """
    s = '''
        0 1 1 1.0
        1 1 1 50.0
        1 2 2 2.0
        2 3 -1 3.0
        3
    '''
    for device in self.devices:
        for use_map in [True, False]:
            fsa = k2.Fsa.from_str(s).to(device)
            fsa.requires_grad_(True)
            fsa_vec = k2.create_fsa_vec([fsa, fsa])
            log_prob = torch.tensor(
                [[[0.1, 0.2, 0.3], [0.04, 0.05, 0.06], [0.0, 0.0, 0.0]],
                 [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.0, 0.0, 0.0]]],
                dtype=torch.float32,
                device=device,
                requires_grad=True)

            a_to_b_map = None
            if use_map:
                a_to_b_map = torch.tensor([0, 0],
                                          dtype=torch.int32,
                                          device=device)

            supervision_segments = torch.tensor([[0, 0, 3], [1, 0, 2]],
                                                dtype=torch.int32)
            dense_fsa_vec = k2.DenseFsaVec(log_prob, supervision_segments)
            out_fsa = k2.intersect_dense(fsa_vec,
                                         dense_fsa_vec,
                                         output_beam=100000,
                                         a_to_b_map=a_to_b_map,
                                         seqframe_idx_name='seqframe',
                                         frame_idx_name='frame')

            # `expected` results are computed using gtn.
            # See https://colab.research.google.com/drive/1FzEFjj5GoCDN2d05D9jE682CkR7QIlnm?usp=sharing
            if use_map:
                expected_seqframe = [0, 1, 2, 3, 0, 1, 2, 3]
                expected_frame = [0, 1, 2, 3, 0, 1, 2, 3]
                expected_arc_scores = [
                    1.2, 50.05, 2.0, 3.0, 1.2, 50.05, 2.0, 3.0
                ]
                expected_fsa_grad = [2.0, 2.0, 2.0, 2.0]
                expected_log_prob_grad = [
                    0.0, 2.0, 0.0, 0.0, 2.0, 0.0, 0.0, 0.0, 2.0,
                    0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
                ]
            else:
                expected_seqframe = [0, 1, 2, 3, 4, 5, 6]
                expected_frame = [0, 1, 2, 3, 0, 1, 2]
                expected_arc_scores = [1.2, 50.05, 2.0, 3.0, 1.2, 2.6, 3.0]
                expected_fsa_grad = [2.0, 1.0, 2.0, 2.0]
                expected_log_prob_grad = [
                    0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0,
                    0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0,
                ]

            assert torch.all(
                torch.eq(out_fsa.seqframe,
                         torch.tensor(expected_seqframe, device=device)))
            assert torch.all(
                torch.eq(out_fsa.frame,
                         torch.tensor(expected_frame, device=device)))

            assert out_fsa.shape == (2, None,
                                     None), 'There should be two FSAs!'

            scores = out_fsa.get_tot_scores(log_semiring=False,
                                            use_double_scores=False)
            scores.sum().backward()

            expected_scores_out_fsa = torch.tensor(expected_arc_scores,
                                                   device=device)
            assert torch.allclose(out_fsa.scores, expected_scores_out_fsa)

            expected_grad_fsa = torch.tensor(expected_fsa_grad,
                                             device=device)
            assert torch.allclose(expected_grad_fsa, fsa.scores.grad)

            expected_grad_log_prob = torch.tensor(
                expected_log_prob_grad, device=device).reshape_as(log_prob)
            assert torch.allclose(expected_grad_log_prob, log_prob.grad)
def get_objf(batch: Dict,
             model: AcousticModel,
             P: k2.Fsa,
             device: torch.device,
             graph_compiler: MmiTrainingGraphCompiler,
             is_training: bool,
             optimizer: Optional[torch.optim.Optimizer] = None):
    """Compute the MMI objective for one batch; optionally take a training step.

    Args:
      batch: Dict providing ``'features'`` (a 3-D tensor) and
        ``'supervisions'`` (with ``'sequence_idx'``, ``'start_frame'``,
        ``'num_frames'`` and ``'text'`` entries).
      model: Acoustic model; its ``subsampling_factor`` is used to convert
        frame indexes into network-output frame indexes.
      P: FSA passed to ``graph_compiler.compile`` together with the texts.
      device: Device on which the model and the graphs are evaluated.
      graph_compiler: Returns the ``(numerator, denominator)`` graphs.
      is_training: If True, run with gradients and update the model.
      optimizer: Used only when ``is_training`` is True.

    Returns:
      A tuple of Python numbers ``(-tot_score, tot_frames, all_frames)``
      taken from ``get_tot_objf_and_num_frames``.

    Note:
      ``den_scale`` is not a parameter; it is read from the enclosing
      (module) scope.
    """
    feature = batch['features']
    supervisions = batch['supervisions']

    # Columns: sequence index, start frame, number of frames (the last two
    # expressed in subsampled frames).
    segment_columns = (
        supervisions['sequence_idx'],
        torch.floor_divide(supervisions['start_frame'],
                           model.subsampling_factor),
        torch.floor_divide(supervisions['num_frames'],
                           model.subsampling_factor),
    )
    supervision_segments = torch.stack(segment_columns, 1).to(torch.int32)

    # Order the segments (and their transcripts) by decreasing duration.
    order = torch.argsort(supervision_segments[:, 2], descending=True)
    supervision_segments = supervision_segments[order]
    all_texts = supervisions['text']
    texts = [all_texts[idx] for idx in order]

    assert feature.ndim == 3
    # At entry the feature is [N, T, C]; the model takes [N, C, T].
    feature = feature.to(device).permute(0, 2, 1)

    if is_training:
        nnet_output = model(feature)
    else:
        with torch.no_grad():
            nnet_output = model(feature)

    # The model emits [N, C, T]; switch back to [N, T, C].
    nnet_output = nnet_output.permute(0, 2, 1)

    if is_training:
        num_graphs, den_graphs = graph_compiler.compile(texts, P)
    else:
        with torch.no_grad():
            num_graphs, den_graphs = graph_compiler.compile(texts, P)

    # Only the numerator is expected to carry gradients, and only in training.
    assert num_graphs.requires_grad == is_training
    assert den_graphs.requires_grad is False
    num_graphs = num_graphs.to(device)
    den_graphs = den_graphs.to(device)

    dense_fsa_vec = k2.DenseFsaVec(nnet_output, supervision_segments)
    assert nnet_output.device == device

    num_lats = k2.intersect_dense(num_graphs, dense_fsa_vec, 10.0)
    den_lats = k2.intersect_dense(den_graphs, dense_fsa_vec, 10.0)

    num_tot_scores = num_lats.get_tot_scores(log_semiring=True,
                                             use_double_scores=True)
    den_tot_scores = den_lats.get_tot_scores(log_semiring=True,
                                             use_double_scores=True)
    tot_scores = num_tot_scores - den_scale * den_tot_scores

    tot_score, tot_frames, all_frames = get_tot_objf_and_num_frames(
        tot_scores, supervision_segments[:, 2])

    if is_training:
        optimizer.zero_grad()
        loss = -tot_score
        loss.backward()
        clip_grad_value_(model.parameters(), 5.0)
        optimizer.step()

    return (
        -tot_score.detach().cpu().item(),
        tot_frames.cpu().item(),
        all_frames.cpu().item(),
    )
def test_two_fsas(self):
    """Intersect two different FSAs with a 2-sequence DenseFsaVec and check
    frame indexes and the gradients w.r.t. the FSA scores."""
    s1 = '''
        0 1 1 1.0
        1 1 1 50.0
        1 2 2 2.0
        2 3 -1 3.0
        3
    '''

    s2 = '''
        0 1 1 1.0
        1 2 2 2.0
        2 3 -1 3.0
        3
    '''
    for device in self.devices:
        fsa1 = k2.Fsa.from_str(s1).to(device)
        fsa2 = k2.Fsa.from_str(s2).to(device)
        for fsa in (fsa1, fsa2):
            fsa.requires_grad_(True)

        fsa_vec = k2.create_fsa_vec([fsa1, fsa2])
        log_prob = torch.tensor(
            [[[0.1, 0.2, 0.3], [0.04, 0.05, 0.06], [0.0, 0.0, 0.0]],
             [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.0, 0.0, 0.0]]],
            dtype=torch.float32,
            device=device,
            requires_grad=True)

        supervision_segments = torch.tensor([[0, 0, 3], [1, 0, 2]],
                                            dtype=torch.int32)
        dense_fsa_vec = k2.DenseFsaVec(log_prob, supervision_segments)
        out_fsa = k2.intersect_dense(fsa_vec,
                                     dense_fsa_vec,
                                     output_beam=100000,
                                     seqframe_idx_name='seqframe',
                                     frame_idx_name='frame')

        expected_seqframe = torch.tensor([0, 1, 2, 3, 4, 5, 6],
                                         device=device)
        expected_frame = torch.tensor([0, 1, 2, 3, 0, 1, 2], device=device)
        assert torch.all(torch.eq(out_fsa.seqframe, expected_seqframe))
        assert torch.all(torch.eq(out_fsa.frame, expected_frame))

        assert out_fsa.shape == (2, None, None), 'There should be two FSAs!'

        scores = out_fsa.get_tot_scores(log_semiring=False,
                                        use_double_scores=False)
        scores.sum().backward()

        # `expected` results are computed using gtn.
        # See https://bit.ly/3oYObeb
        #
        # TODO(dan):: fix this.. The arc-score and log_prob-gradient checks
        # below are still disabled:
        # expected_scores_out_fsa = torch.tensor(
        #     [1.2, 2.06, 3.0, 1.2, 50.5, 2.0, 3.0])
        # expected_grad_log_prob = torch.tensor([
        #     0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0, 0, 0, 0.0, 1.0, 0.0, 0.0,
        #     1.0, 0.0, 0.0, 0.0, 1.0
        # ]).reshape_as(log_prob)
        # assert torch.allclose(out_fsa.scores, expected_scores_out_fsa)
        expected_grad_fsa1 = torch.tensor([1.0, 1.0, 1.0, 1.0],
                                          device=device)
        expected_grad_fsa2 = torch.tensor([1.0, 1.0, 1.0], device=device)

        assert torch.allclose(expected_grad_fsa1, fsa1.scores.grad)
        assert torch.allclose(expected_grad_fsa2, fsa2.scores.grad)
def get_objf(
    batch: Dict,
    model: AcousticModel,
    device: torch.device,
    graph_compiler: CtcTrainingGraphCompiler,
    training: bool,
    optimizer: Optional[torch.optim.Optimizer] = None,
):
    """Compute the CTC objective for one batch; optionally take a training step.

    Args:
      batch: Dict providing ``"inputs"`` (a 3-D tensor) and
        ``"supervisions"`` (with ``"sequence_idx"``, ``"start_frame"``,
        ``"num_frames"`` and ``"text"`` entries).
      model: Acoustic model; its ``subsampling_factor`` is used to convert
        frame indexes into network-output frame indexes.
      device: Device on which the model and the graphs are evaluated.
      graph_compiler: Compiles the transcripts into a decoding graph.
      training: If True, run with gradients and update the model.
      optimizer: Used only when ``training`` is True.

    Returns:
      A tuple of Python numbers ``(-tot_score, tot_frames, all_frames)``
      taken from ``get_tot_objf_and_num_frames``.
    """
    feature = batch["inputs"]
    supervisions = batch["supervisions"]

    # One row per supervision: sequence index, start frame and number of
    # frames, the last two expressed in subsampled frames.
    sequence_idx = supervisions["sequence_idx"]
    start_frame = torch.floor_divide(
        supervisions["start_frame"], model.subsampling_factor
    )
    num_frames = torch.floor_divide(
        supervisions["num_frames"], model.subsampling_factor
    )
    supervision_segments = torch.stack(
        (sequence_idx, start_frame, num_frames), 1
    ).to(torch.int32)

    # Order the segments (and their transcripts) by decreasing duration.
    order = torch.argsort(supervision_segments[:, 2], descending=True)
    supervision_segments = supervision_segments[order]
    all_texts = supervisions["text"]
    texts = [all_texts[idx] for idx in order]

    assert feature.ndim == 3
    # At entry the feature is [N, T, C]; the model takes [N, C, T].
    feature = feature.to(device).permute(0, 2, 1)

    if training:
        nnet_output = model(feature)
    else:
        with torch.no_grad():
            nnet_output = model(feature)

    # The model emits [N, C, T]; switch back to [N, T, C].
    nnet_output = nnet_output.permute(0, 2, 1)

    decoding_graph = graph_compiler.compile(texts).to(device)
    dense_fsa_vec = k2.DenseFsaVec(nnet_output, supervision_segments)

    assert decoding_graph.is_cuda()
    assert decoding_graph.device == device
    assert nnet_output.device == device

    target_graph = k2.intersect_dense(decoding_graph, dense_fsa_vec, 10.0)
    tot_scores = target_graph.get_tot_scores(
        log_semiring=True, use_double_scores=True
    )
    tot_score, tot_frames, all_frames = get_tot_objf_and_num_frames(
        tot_scores, supervision_segments[:, 2]
    )

    if training:
        optimizer.zero_grad()
        loss = -tot_score
        loss.backward()
        clip_grad_value_(model.parameters(), 5.0)
        optimizer.step()

    return (
        -tot_score.detach().cpu().item(),
        tot_frames.cpu().item(),
        all_frames.cpu().item(),
    )
def get_objf(batch: Dict,
             model: AcousticModel,
             P: k2.Fsa,
             device: torch.device,
             graph_compiler: MmiTrainingGraphCompiler,
             is_training: bool,
             tb_writer: Optional[SummaryWriter] = None,
             global_batch_idx_train: Optional[int] = None,
             optimizer: Optional[torch.optim.Optimizer] = None):
    """Compute the MMI objective for one batch; optionally train and log.

    Args:
      batch: Dict providing ``'features'`` (a 3-D tensor) and
        ``'supervisions'`` (with ``'sequence_idx'``, ``'start_frame'``,
        ``'num_frames'`` and ``'text'`` entries).
      model: Acoustic model, possibly wrapped in ``DDP`` (in which case
        ``subsampling_factor`` is read from ``model.module``).
      P: FSA passed to ``graph_compiler.compile`` together with the texts.
      device: Device on which the model and the graphs are evaluated.
      graph_compiler: Returns the ``(numerator, denominator)`` graphs.
      is_training: If True, run with gradients and update the model.
      tb_writer: Optional TensorBoard writer for training diagnostics.
      global_batch_idx_train: Global training batch index. Diagnostics are
        written only when both ``tb_writer`` and this index are given and
        the index is a multiple of 200.
      optimizer: Used only when ``is_training`` is True.

    Returns:
      A tuple of Python numbers ``(-tot_score, tot_frames, all_frames)``
      taken from ``get_tot_objf_and_num_frames``.

    Note:
      ``den_scale`` is not a parameter; it is read from the enclosing
      (module) scope.
    """
    feature = batch['features']
    supervisions = batch['supervisions']
    subsampling_factor = model.module.subsampling_factor if isinstance(
        model, DDP) else model.subsampling_factor

    # Columns: sequence index, start frame, number of frames (the last two
    # expressed in subsampled frames).
    supervision_segments = torch.stack(
        (supervisions['sequence_idx'],
         torch.floor_divide(supervisions['start_frame'],
                            subsampling_factor),
         torch.floor_divide(supervisions['num_frames'],
                            subsampling_factor)), 1).to(torch.int32)
    # Order the segments (and their transcripts) by decreasing duration.
    indices = torch.argsort(supervision_segments[:, 2], descending=True)
    supervision_segments = supervision_segments[indices]
    texts = supervisions['text']
    texts = [texts[idx] for idx in indices]

    assert feature.ndim == 3
    feature = feature.to(device)
    # at entry, feature is [N, T, C]
    feature = feature.permute(0, 2, 1)  # now feature is [N, C, T]

    if is_training:
        nnet_output = model(feature)
    else:
        with torch.no_grad():
            nnet_output = model(feature)

    # nnet_output is [N, C, T]
    nnet_output = nnet_output.permute(0, 2, 1)  # now nnet_output is [N, T, C]

    if is_training:
        num, den = graph_compiler.compile(texts, P)
    else:
        with torch.no_grad():
            num, den = graph_compiler.compile(texts, P)

    # Only the numerator is expected to carry gradients, and only in training.
    assert num.requires_grad == is_training
    assert den.requires_grad is False
    num = num.to(device)
    den = den.to(device)

    dense_fsa_vec = k2.DenseFsaVec(nnet_output, supervision_segments)
    assert nnet_output.device == device

    num = k2.intersect_dense(num, dense_fsa_vec, 10.0)
    den = k2.intersect_dense(den, dense_fsa_vec, 10.0)

    num_tot_scores = num.get_tot_scores(log_semiring=True,
                                        use_double_scores=True)
    den_tot_scores = den.get_tot_scores(log_semiring=True,
                                        use_double_scores=True)
    tot_scores = num_tot_scores - den_scale * den_tot_scores

    (tot_score, tot_frames,
     all_frames) = get_tot_objf_and_num_frames(tot_scores,
                                               supervision_segments[:, 2])

    if is_training:
        # A single guard for every diagnostic below. `global_batch_idx_train`
        # defaults to None, so it must be checked before taking the modulo;
        # otherwise passing only `tb_writer` raises a TypeError.
        log_diagnostics = (tb_writer is not None
                           and global_batch_idx_train is not None
                           and global_batch_idx_train % 200 == 0)

        def maybe_log_gradients(tag: str):
            if log_diagnostics:
                tb_writer.add_scalars(tag,
                                      measure_gradient_norms(model,
                                                             norm='l1'),
                                      global_step=global_batch_idx_train)

        optimizer.zero_grad()
        (-tot_score).backward()
        maybe_log_gradients('train/grad_norms')
        clip_grad_value_(model.parameters(), 5.0)
        maybe_log_gradients('train/clipped_grad_norms')
        if log_diagnostics:
            # Once in a time we will perform a more costly diagnostic
            # to check the relative parameter change per minibatch.
            deltas = optim_step_and_measure_param_change(model, optimizer)
            tb_writer.add_scalars('train/relative_param_change_per_minibatch',
                                  deltas,
                                  global_step=global_batch_idx_train)
        else:
            optimizer.step()

    ans = -tot_score.detach().cpu().item(), tot_frames.cpu().item(
    ), all_frames.cpu().item()
    return ans