def test_create_sparse(self):
    s = '''
        0 1 10 0.1
        0 1 11 0.2
        1 2 20 0.3
        2 3 21 0.4
        2 3 24 0.5
        3 4 -1 0.6
        4
    '''
    for device in self.devices:
        fsa = k2.Fsa.from_str(s).to(device)
        fsa.phones = torch.tensor([10, 11, 20, 21, 24, -1],
                                  dtype=torch.int32,
                                  device=device)
        fsa.seqframes = torch.tensor([0, 0, 1, 2, 2, 3],
                                     dtype=torch.int32,
                                     device=device)
        fsa.requires_grad_(True)
        tensor = k2.create_sparse(rows=fsa.seqframes,
                                  cols=fsa.phones,
                                  values=fsa.scores,
                                  size=(6, 25),
                                  min_col_index=0)
        assert tensor.device == device
        assert tensor.is_sparse
        # The final arc has phone -1, so it is filtered out by
        # min_col_index=0; hence the [:-1] slices below.
        assert torch.allclose(tensor._indices()[0],
                              fsa.seqframes[:-1].to(torch.int64))
        assert torch.allclose(tensor._indices()[1],
                              fsa.phones[:-1].to(torch.int64))
        assert torch.allclose(tensor._values(), fsa.scores[:-1])
        assert tensor.requires_grad == fsa.requires_grad
        assert tensor.dtype == fsa.scores.dtype
def create_sparse_wrapped(
    indices: List[torch.Tensor],
    values: torch.Tensor,
    size: Optional[Union[Tuple[int, int], Tuple[int, int, int]]] = None,
    min_col_index: Optional[int] = None,
) -> torch.Tensor:
    """Wraps up k2.create_sparse to create 2- or 3-dimensional sparse tensors.
    """
    assert size is None or len(indices) == len(size)

    if len(indices) == 2:
        return k2.create_sparse(
            rows=indices[0],
            cols=indices[1],
            values=values,
            size=size,
            min_col_index=min_col_index,
        )
    elif len(indices) == 3:
        assert indices[0].ndim == indices[1].ndim == indices[2].ndim == 1
        assert (indices[0].numel() == indices[1].numel() ==
                indices[2].numel() == values.numel())

        if min_col_index is not None:
            assert isinstance(min_col_index, int)
            kept_indices = indices[-1] >= min_col_index
            indices = [i[kept_indices] for i in indices]
            values = values[kept_indices]
        if size is not None:
            return torch.sparse_coo_tensor(
                torch.stack(indices),
                values,
                size=size,
                device=values.device,
                requires_grad=values.requires_grad,
            )
        else:
            return torch.sparse_coo_tensor(
                torch.stack(indices),
                values,
                device=values.device,
                requires_grad=values.requires_grad,
            )
    else:
        raise ValueError(f"len(indices) = {len(indices)}")
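

# A minimal usage sketch, not part of the original code, exercising the
# 3-dimensional branch of create_sparse_wrapped above. The index/value
# tensors are hypothetical; it assumes torch (and the typing/k2 imports of
# the original module) are available.
def _example_create_sparse_wrapped_3d():
    import torch  # local import keeps the sketch self-contained

    rows = torch.tensor([0, 0, 1], dtype=torch.int64)
    cols = torch.tensor([2, 5, 3], dtype=torch.int64)
    extra = torch.tensor([1, -1, 4], dtype=torch.int64)
    values = torch.tensor([0.1, 0.2, 0.3])

    t = create_sparse_wrapped(indices=[rows, cols, extra],
                              values=values,
                              size=(2, 6, 5),
                              min_col_index=0)
    # The entry whose last index is -1 is dropped by min_col_index=0,
    # so the resulting 2x6x5 sparse tensor stores only two values.
    assert t.is_sparse
    assert t._values().numel() == 2
    return t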
def get_loss(batch: Dict,
             model: AcousticModel,
             P: k2.Fsa,
             device: torch.device,
             graph_compiler: MmiMbrTrainingGraphCompiler,
             is_training: bool,
             optimizer: Optional[torch.optim.Optimizer] = None):
    assert P.device == device
    feature = batch['features']
    supervisions = batch['supervisions']
    supervision_segments = torch.stack(
        (supervisions['sequence_idx'],
         torch.floor_divide(supervisions['start_frame'],
                            model.subsampling_factor),
         torch.floor_divide(supervisions['num_frames'],
                            model.subsampling_factor)), 1).to(torch.int32)
    indices = torch.argsort(supervision_segments[:, 2], descending=True)
    supervision_segments = supervision_segments[indices]

    texts = supervisions['text']
    texts = [texts[idx] for idx in indices]

    assert feature.ndim == 3
    # print(supervision_segments[:, 1] + supervision_segments[:, 2])

    feature = feature.to(device)
    # at entry, feature is [N, T, C]
    feature = feature.permute(0, 2, 1)  # now feature is [N, C, T]
    if is_training:
        nnet_output = model(feature)
    else:
        with torch.no_grad():
            nnet_output = model(feature)

    # nnet_output is [N, C, T]
    nnet_output = nnet_output.permute(0, 2, 1)  # now nnet_output is [N, T, C]

    if is_training:
        num_graph, den_graph, decoding_graph = graph_compiler.compile(texts, P)
    else:
        with torch.no_grad():
            num_graph, den_graph, decoding_graph = graph_compiler.compile(
                texts, P)

    assert num_graph.requires_grad == is_training
    assert den_graph.requires_grad is False
    assert decoding_graph.requires_grad is False
    assert len(
        decoding_graph.shape) == 2 or decoding_graph.shape == (1, None, None)

    num_graph = num_graph.to(device)
    den_graph = den_graph.to(device)
    decoding_graph = decoding_graph.to(device)

    dense_fsa_vec = k2.DenseFsaVec(nnet_output, supervision_segments)
    assert nnet_output.device == device

    num_lats = k2.intersect_dense(num_graph,
                                  dense_fsa_vec,
                                  10.0,
                                  seqframe_idx_name='seqframe_idx')

    mbr_lats = k2.intersect_dense_pruned(decoding_graph,
                                         dense_fsa_vec,
                                         20.0,
                                         7.0,
                                         30,
                                         10000,
                                         seqframe_idx_name='seqframe_idx')

    if True:
        # WARNING: the else branch is not working at present
        # (the total loss is not stable)
        den_lats = k2.intersect_dense(den_graph, dense_fsa_vec, 10.0)
    else:
        # in this case, we can remove den_graph
        den_lats = mbr_lats

    num_tot_scores = num_lats.get_tot_scores(log_semiring=True,
                                             use_double_scores=True)

    den_tot_scores = den_lats.get_tot_scores(log_semiring=True,
                                             use_double_scores=True)

    if id(den_lats) == id(mbr_lats):
        # Some entries in den_tot_scores may be -inf.
        # The corresponding sequences are discarded/ignored.
        finite_indexes = torch.isfinite(den_tot_scores)
        den_tot_scores = den_tot_scores[finite_indexes]
        num_tot_scores = num_tot_scores[finite_indexes]
    else:
        finite_indexes = None

    # `den_scale` is assumed to be defined at module level (e.g. den_scale = 1.0).
    tot_scores = num_tot_scores - den_scale * den_tot_scores
    (tot_score, tot_frames,
     all_frames) = get_tot_objf_and_num_frames(tot_scores,
                                               supervision_segments[:, 2],
                                               finite_indexes)

    num_rows = dense_fsa_vec.scores.shape[0]
    num_cols = dense_fsa_vec.scores.shape[1] - 1
    mbr_num_sparse = k2.create_sparse(rows=num_lats.seqframe_idx,
                                      cols=num_lats.phones,
                                      values=num_lats.get_arc_post(True,
                                                                   True).exp(),
                                      size=(num_rows, num_cols),
                                      min_col_index=0)

    mbr_den_sparse = k2.create_sparse(rows=mbr_lats.seqframe_idx,
                                      cols=mbr_lats.phones,
                                      values=mbr_lats.get_arc_post(True,
                                                                   True).exp(),
                                      size=(num_rows, num_cols),
                                      min_col_index=0)

    # NOTE: Due to limited support of PyTorch's autograd for sparse tensors,
    # we cannot use (mbr_num_sparse - mbr_den_sparse) here
    #
    # The following works only for torch >= 1.7.0
    mbr_loss = torch.sparse.sum(
        k2.sparse.abs((mbr_num_sparse + (-mbr_den_sparse)).coalesce()))

    mmi_loss = -tot_score

    total_loss = mmi_loss + mbr_loss

    if is_training:
        optimizer.zero_grad()
        total_loss.backward()
        clip_grad_value_(model.parameters(), 5.0)
        optimizer.step()

    ans = (
        mmi_loss.detach().cpu().item(),
        mbr_loss.detach().cpu().item(),
        tot_frames.cpu().item(),
        all_frames.cpu().item(),
    )
    return ans
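

# A small self-contained sketch, not part of the original code, of the sparse
# MBR-loss pattern used in get_loss above: because direct subtraction of two
# sparse tensors is not well supported by PyTorch's autograd, the code adds
# the negated tensor and coalesces, then sums the absolute values of the
# result. The tensors below are hypothetical; only plain torch is assumed.
def _example_sparse_abs_diff():
    import torch  # local import keeps the sketch self-contained

    indices = torch.tensor([[0, 1],
                            [2, 0]])  # entries at (0, 2) and (1, 0)
    a = torch.sparse_coo_tensor(indices, torch.tensor([0.5, 0.5]), size=(2, 3))
    b = torch.sparse_coo_tensor(indices, torch.tensor([0.2, 0.7]), size=(2, 3))

    diff = (a + (-b)).coalesce()
    # Summing the absolute values of the coalesced difference corresponds to
    # torch.sparse.sum(k2.sparse.abs(diff)) in get_loss.
    loss = diff.values().abs().sum()  # |0.5 - 0.2| + |0.5 - 0.7| = 0.5
    return loss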