from typing import List, Tuple
import math

import torch
import k2

# NOTE: MmiTrainingGraphCompiler and get_tot_objf_and_num_frames are assumed
# to be importable from the surrounding project; they are not defined here.


def _intersect_calc_scores_mmi_exact(
    self,
    dense_fsa_vec: k2.DenseFsaVec,
    num_graphs: 'k2.Fsa',
    den_graph: 'k2.Fsa',
    return_lats: bool = True,
):
    device = dense_fsa_vec.device
    assert device == num_graphs.device and device == den_graph.device

    num_fsas = num_graphs.shape[0]
    assert dense_fsa_vec.dim0() == num_fsas

    den_graph = den_graph.clone()
    num_graphs = num_graphs.clone()

    num_den_graphs = k2.cat([num_graphs, den_graph])

    # NOTE: The a_to_b_map in k2.intersect_dense must be sorted,
    # so the following reorders num_den_graphs.

    # [0, 1, 2, ... ]
    num_graphs_indexes = torch.arange(num_fsas, dtype=torch.int32)

    # [num_fsas, num_fsas, num_fsas, ... ]
    den_graph_indexes = torch.tensor([num_fsas] * num_fsas, dtype=torch.int32)

    # [0, num_fsas, 1, num_fsas, 2, num_fsas, ... ]
    num_den_graphs_indexes = torch.stack(
        [num_graphs_indexes, den_graph_indexes]).t().reshape(-1).to(device)

    num_den_reordered_graphs = k2.index_fsa(num_den_graphs, num_den_graphs_indexes)

    # [[0, 1, 2, ...]]
    a_to_b_map = torch.arange(num_fsas, dtype=torch.int32).reshape(1, -1)

    # [[0, 1, 2, ...]] -> [0, 0, 1, 1, 2, 2, ... ]
    a_to_b_map = a_to_b_map.repeat(2, 1).t().reshape(-1).to(device)

    num_den_lats = k2.intersect_dense(
        a_fsas=num_den_reordered_graphs,
        b_fsas=dense_fsa_vec,
        output_beam=self.intersect_conf.output_beam,
        a_to_b_map=a_to_b_map,
        seqframe_idx_name="seqframe_idx" if return_lats else None,
    )

    num_den_tot_scores = num_den_lats.get_tot_scores(log_semiring=True,
                                                     use_double_scores=True)

    num_tot_scores = num_den_tot_scores[::2]
    den_tot_scores = num_den_tot_scores[1::2]

    if return_lats:
        lat_slice = torch.arange(num_fsas, dtype=torch.int32).to(device) * 2
        return (
            num_tot_scores,
            den_tot_scores,
            k2.index_fsa(num_den_lats, lat_slice),
            k2.index_fsa(num_den_lats, lat_slice + 1),
        )
    else:
        return num_tot_scores, den_tot_scores, None, None
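# A minimal sketch of the interleaving trick used above, runnable with torch
# alone. For num_fsas = 3 the indexes come out as [0, 3, 1, 3, 2, 3], i.e.
# each utterance's numerator graph followed by the shared denominator graph
# (which sits at position num_fsas after k2.cat). The helper name and values
# here are illustrative, not part of the original code.
def _sketch_interleave_indexes():
    num_fsas = 3
    num_idx = torch.arange(num_fsas, dtype=torch.int32)               # [0, 1, 2]
    den_idx = torch.tensor([num_fsas] * num_fsas, dtype=torch.int32)  # [3, 3, 3]
    interleaved = torch.stack([num_idx, den_idx]).t().reshape(-1)
    assert interleaved.tolist() == [0, 3, 1, 3, 2, 3]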
def test_cat_fsa_vec(self):
    for device in self.devices:
        s = '''
            0 1 1 0.1
            0 1 2 0.2
            1 2 -1 0.3
            2
        '''
        fsa1 = k2.Fsa.from_str(s).to(device)
        fsa1.tensor_attr1 = torch.tensor([1, 2, 3]).to(device)
        fsa1.tensor_attr2 = torch.tensor([4, 5, 6]).to(device)
        fsa1.non_tensor_attr1 = 'fsa1'
        fsa1.ragged_tensor_attr1 = \
            k2.RaggedTensor('[[1 2] [] [3 4 5]]').to(device)
        fsa1.ragged_tensor_attr2 = \
            k2.RaggedTensor('[[1 20] [30] [5]]').to(device)

        fsa2 = k2.Fsa.from_str(s).to(device)
        fsa2.tensor_attr1 = torch.tensor([10, 20, 30]).to(device)
        fsa2.tensor_attr3 = torch.tensor([40, 50, 60]).to(device)
        fsa2.non_tensor_attr1 = 'fsa'
        fsa2.non_tensor_attr2 = 'fsa2'
        fsa2.ragged_tensor_attr1 = \
            k2.RaggedTensor('[[3] [4 5] [6 7]]').to(device)
        fsa2.ragged_tensor_attr3 = \
            k2.RaggedTensor('[[1 0] [0] [-1]]').to(device)

        fsa_vec1 = k2.create_fsa_vec([fsa1])
        fsa_vec2 = k2.create_fsa_vec([fsa2])
        fsa_vec = k2.cat([fsa_vec1, fsa_vec2])

        # The arcs are concatenated in input order.
        assert str(fsa_vec[0].arcs) == str(fsa1.arcs)
        assert str(fsa_vec[1].arcs) == str(fsa2.arcs)

        # Tensor attributes defined on only one input are dropped.
        assert not hasattr(fsa_vec, 'tensor_attr2')
        assert not hasattr(fsa_vec, 'tensor_attr3')

        assert fsa_vec.non_tensor_attr1 == fsa1.non_tensor_attr1
        assert fsa_vec.non_tensor_attr2 == fsa2.non_tensor_attr2

        # Tensor attributes defined on all inputs are concatenated.
        assert torch.all(
            torch.eq(fsa_vec.tensor_attr1,
                     torch.tensor([1, 2, 3, 10, 20, 30]).to(device)))

        # The same rules apply to ragged tensor attributes.
        assert fsa_vec.ragged_tensor_attr1 == k2.RaggedTensor([
            [1, 2],
            [],
            [3, 4, 5],
            [3],
            [4, 5],
            [6, 7],
        ]).to(device)

        assert not hasattr(fsa_vec, 'ragged_tensor_attr2')
        assert not hasattr(fsa_vec, 'ragged_tensor_attr3')
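# A smaller sketch of the k2.cat attribute semantics the test above exercises,
# assuming a working k2 installation; the attribute names `shared` and
# `only_a` are made up for illustration. A tensor attribute survives
# concatenation only if every input FsaVec defines it.
def _sketch_cat_attr_semantics():
    s = '''
        0 1 1 0.5
        1 2 -1 0.5
        2
    '''
    a = k2.create_fsa_vec([k2.Fsa.from_str(s)])
    b = k2.create_fsa_vec([k2.Fsa.from_str(s)])
    a.shared = torch.tensor([1, 2])  # defined on both inputs: concatenated
    b.shared = torch.tensor([3, 4])
    a.only_a = torch.tensor([5, 6])  # defined on one input only: dropped
    c = k2.cat([a, b])
    assert c.shared.tolist() == [1, 2, 3, 4]
    assert not hasattr(c, 'only_a')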
def _intersect_device(
    a_fsas: k2.Fsa,
    b_fsas: k2.Fsa,
    b_to_a_map: torch.Tensor,
    sorted_match_a: bool,
    batch_size: int = 500,
):
    """Wrap k2.intersect_device.

    This is a wrapper of k2.intersect_device whose purpose is to split
    b_fsas into several batches and process each batch separately
    to avoid CUDA OOM errors.

    The arguments and return value of this function are the same as
    those of k2.intersect_device.

    NOTE: You can decrease batch_size in case of CUDA out-of-memory errors.
    """
    num_fsas = b_fsas.shape[0]
    if num_fsas <= batch_size:
        return k2.intersect_device(
            a_fsas, b_fsas, b_to_a_map=b_to_a_map, sorted_match_a=sorted_match_a
        )

    # Tile [0, num_fsas) into contiguous (start, end) chunks of
    # at most batch_size FSAs each.
    num_batches = int(math.ceil(float(num_fsas) / batch_size))
    splits = []
    for i in range(num_batches):
        start = i * batch_size
        end = min(start + batch_size, num_fsas)
        splits.append((start, end))

    ans = []
    for start, end in splits:
        indexes = torch.arange(start, end).to(b_to_a_map)

        fsas = k2.index_fsa(b_fsas, indexes)
        b_to_a = k2.index_select(b_to_a_map, indexes)

        path_lats = k2.intersect_device(
            a_fsas, fsas, b_to_a_map=b_to_a, sorted_match_a=sorted_match_a
        )
        ans.append(path_lats)

    return k2.cat(ans)
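# A quick check of the batch-boundary arithmetic used above (pure Python,
# no k2 needed); the helper name and figures are illustrative:
def _sketch_batch_splits():
    num_fsas, batch_size = 1200, 500
    num_batches = int(math.ceil(float(num_fsas) / batch_size))
    splits = [(i * batch_size, min((i + 1) * batch_size, num_fsas))
              for i in range(num_batches)]
    assert splits == [(0, 500), (500, 1000), (1000, 1200)]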
def _compute_mmi_loss_exact_optimized(
    nnet_output: torch.Tensor,
    texts: List[str],
    supervision_segments: torch.Tensor,
    graph_compiler: MmiTrainingGraphCompiler,
    P: k2.Fsa,
    den_scale: float = 1.0
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    '''
    The function name contains `exact`, which means it uses a version of
    intersection without pruning.

    `optimized` in the function name means this function is optimized
    in that it calls k2.intersect_dense only once.

    Note:
      It is faster at the cost of using more memory.

    Args:
      nnet_output:
        A 3-D tensor of shape [N, T, C].
      texts:
        The transcripts. Each element consists of space-separated words.
      supervision_segments:
        A 2-D tensor that will be passed to :func:`k2.DenseFsaVec`.
      graph_compiler:
        Used to build num_graphs and den_graphs.
      P:
        Represents a bigram FSA.
      den_scale:
        The scale applied to the denominator tot_scores.
    Returns:
      A tuple of 3 tensors: (tot_score, tot_frames, all_frames).
    '''
    num_graphs, den_graphs = graph_compiler.compile(texts, P, replicate_den=False)

    dense_fsa_vec = k2.DenseFsaVec(nnet_output, supervision_segments)

    device = num_graphs.device
    num_fsas = num_graphs.shape[0]
    assert dense_fsa_vec.dim0() == num_fsas

    assert den_graphs.shape[0] == 1

    # The aux_labels of num_graphs is a k2.RaggedInt,
    # but it is a torch.Tensor for den_graphs.
    #
    # The following converts den_graphs.aux_labels
    # from torch.Tensor to k2.RaggedInt so that
    # we can use k2.cat() later.
    den_graphs.convert_attr_to_ragged_(name='aux_labels')

    # The motivation to concatenate num_graphs and den_graphs
    # is to reduce the number of calls to k2.intersect_dense.
    num_den_graphs = k2.cat([num_graphs, den_graphs])

    # NOTE: The a_to_b_map in k2.intersect_dense must be sorted,
    # so the following reorders num_den_graphs.
    #
    # The following code computes a_to_b_map.

    # [0, 1, 2, ... ]
    num_graphs_indexes = torch.arange(num_fsas, dtype=torch.int32)

    # [num_fsas, num_fsas, num_fsas, ... ]
    den_graphs_indexes = torch.tensor([num_fsas] * num_fsas, dtype=torch.int32)

    # [0, num_fsas, 1, num_fsas, 2, num_fsas, ... ]
    num_den_graphs_indexes = torch.stack(
        [num_graphs_indexes, den_graphs_indexes]).t().reshape(-1).to(device)

    num_den_reordered_graphs = k2.index(num_den_graphs, num_den_graphs_indexes)

    # [[0, 1, 2, ...]]
    a_to_b_map = torch.arange(num_fsas, dtype=torch.int32).reshape(1, -1)

    # [[0, 1, 2, ...]] -> [0, 0, 1, 1, 2, 2, ... ]
    a_to_b_map = a_to_b_map.repeat(2, 1).t().reshape(-1).to(device)

    num_den_lats = k2.intersect_dense(num_den_reordered_graphs,
                                      dense_fsa_vec,
                                      output_beam=10.0,
                                      a_to_b_map=a_to_b_map)

    num_den_tot_scores = num_den_lats.get_tot_scores(log_semiring=True,
                                                     use_double_scores=True)

    num_tot_scores = num_den_tot_scores[::2]
    den_tot_scores = num_den_tot_scores[1::2]

    tot_scores = num_tot_scores - den_scale * den_tot_scores
    tot_score, tot_frames, all_frames = get_tot_objf_and_num_frames(
        tot_scores, supervision_segments[:, 2])
    return tot_score, tot_frames, all_frames
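# The single intersect_dense call above returns total scores interleaved as
# [num_0, den_0, num_1, den_1, ...]; the strided slices recover each stream.
# A toy check of that slicing with made-up score values (the helper name is
# hypothetical):
def _sketch_score_deinterleave():
    scores = torch.tensor([-10.0, -12.0, -20.0, -25.0])  # [num_0, den_0, num_1, den_1]
    num, den = scores[::2], scores[1::2]
    den_scale = 1.0
    tot = num - den_scale * den
    assert tot.tolist() == [2.0, 5.0]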
def forward(
    self,
    nnet_output: torch.Tensor,
    texts: List[str],
    supervision_segments: torch.Tensor
) -> Tuple[torch.Tensor, int, int]:
    num_graphs, den_graphs = self.graph_compiler.compile(
        texts, self.P, replicate_den=False)

    dense_fsa_vec = k2.DenseFsaVec(nnet_output, supervision_segments)

    device = num_graphs.device
    num_fsas = num_graphs.shape[0]
    assert dense_fsa_vec.dim0() == num_fsas

    assert den_graphs.shape[0] == 1

    # The aux_labels of num_graphs is a k2.RaggedInt,
    # but it is a torch.Tensor for den_graphs.
    #
    # The following converts den_graphs.aux_labels
    # from torch.Tensor to k2.RaggedInt so that
    # we can use k2.cat() later.
    den_graphs.convert_attr_to_ragged_(name='aux_labels')

    num_den_graphs = k2.cat([num_graphs, den_graphs])

    # NOTE: The a_to_b_map in k2.intersect_dense must be sorted,
    # so the following reorders num_den_graphs.

    # [0, 1, 2, ... ]
    num_graphs_indexes = torch.arange(num_fsas, dtype=torch.int32)

    # [num_fsas, num_fsas, num_fsas, ... ]
    den_graphs_indexes = torch.tensor([num_fsas] * num_fsas, dtype=torch.int32)

    # [0, num_fsas, 1, num_fsas, 2, num_fsas, ... ]
    num_den_graphs_indexes = torch.stack(
        [num_graphs_indexes, den_graphs_indexes]).t().reshape(-1).to(device)

    num_den_reordered_graphs = k2.index(num_den_graphs, num_den_graphs_indexes)

    # [[0, 1, 2, ...]]
    a_to_b_map = torch.arange(num_fsas, dtype=torch.int32).reshape(1, -1)

    # [[0, 1, 2, ...]] -> [0, 0, 1, 1, 2, 2, ... ]
    a_to_b_map = a_to_b_map.repeat(2, 1).t().reshape(-1).to(device)

    num_den_lats = k2.intersect_dense(num_den_reordered_graphs,
                                      dense_fsa_vec,
                                      output_beam=10.0,
                                      a_to_b_map=a_to_b_map)

    num_den_tot_scores = num_den_lats.get_tot_scores(
        log_semiring=True, use_double_scores=True)

    num_tot_scores = num_den_tot_scores[::2]
    den_tot_scores = num_den_tot_scores[1::2]

    tot_scores = num_tot_scores - self.den_scale * den_tot_scores
    tot_score, tot_frames, all_frames = get_tot_objf_and_num_frames(
        tot_scores, supervision_segments[:, 2])
    return tot_score, tot_frames, all_frames
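# A hedged sketch of how this forward might be driven; constructing the graph
# compiler and the loss module itself is elided, and `mmi_loss`, the shapes,
# and the texts are hypothetical. Each supervision_segments row is
# [sequence_index, start_frame, num_frames], and rows should be sorted by
# decreasing num_frames, as k2.DenseFsaVec expects.
def _sketch_forward_usage(mmi_loss):
    nnet_output = torch.randn(2, 100, 500).log_softmax(dim=-1)  # [N, T, C]
    supervision_segments = torch.tensor(
        [[0, 0, 100],
         [1, 0, 80]], dtype=torch.int32)
    texts = ['HELLO WORLD', 'FOO BAR']
    return mmi_loss(nnet_output, texts, supervision_segments)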