def test_index_fsa(self):
    devices = [torch.device('cpu')]
    if torch.cuda.is_available():
        devices.append(torch.device('cuda', 0))
    for device in devices:
        s1 = '''
            0 1 1 0.1
            1 2 -1 0.2
            2
        '''
        s2 = '''
            0 1 -1 1.0
            1
        '''
        fsa1 = k2.Fsa.from_str(s1)
        fsa1.tensor_attr = torch.tensor([10, 20], dtype=torch.int32)
        fsa1.ragged_attr = k2.ragged.create_ragged2([[11, 12], [21, 22, 23]])

        fsa2 = k2.Fsa.from_str(s2)
        fsa2.tensor_attr = torch.tensor([100], dtype=torch.int32)
        fsa2.ragged_attr = k2.ragged.create_ragged2([[111]])

        fsa1 = fsa1.to(device)
        fsa2 = fsa2.to(device)

        fsa_vec = k2.create_fsa_vec([fsa1, fsa2])

        single1 = k2.index_fsa(
            fsa_vec, torch.tensor([0], dtype=torch.int32, device=device))
        assert torch.all(torch.eq(fsa1.tensor_attr, single1.tensor_attr))
        assert str(single1.ragged_attr) == str(fsa1.ragged_attr)
        assert single1.device == device

        single2 = k2.index_fsa(
            fsa_vec, torch.tensor([1], dtype=torch.int32, device=device))
        assert torch.all(torch.eq(fsa2.tensor_attr, single2.tensor_attr))
        assert str(single2.ragged_attr) == str(fsa2.ragged_attr)
        assert single2.device == device

        multiples = k2.index_fsa(
            fsa_vec,
            torch.tensor([0, 1, 0, 1, 1], dtype=torch.int32, device=device))
        assert multiples.shape == (5, None, None)
        assert torch.all(
            torch.eq(
                multiples.tensor_attr,
                torch.cat(
                    (fsa1.tensor_attr, fsa2.tensor_attr, fsa1.tensor_attr,
                     fsa2.tensor_attr, fsa2.tensor_attr))))
        assert str(multiples.ragged_attr) == str(
            k2.ragged.append([
                fsa1.ragged_attr, fsa2.ragged_attr, fsa1.ragged_attr,
                fsa2.ragged_attr, fsa2.ragged_attr
            ], axis=0))  # noqa
        assert multiples.device == device
def _intersect_calc_scores_mmi_exact(
    self,
    dense_fsa_vec: k2.DenseFsaVec,
    num_graphs: 'k2.Fsa',
    den_graph: 'k2.Fsa',
    return_lats: bool = True,
):
    device = dense_fsa_vec.device
    assert device == num_graphs.device and device == den_graph.device

    num_fsas = num_graphs.shape[0]
    assert dense_fsa_vec.dim0() == num_fsas

    den_graph = den_graph.clone()
    num_graphs = num_graphs.clone()

    num_den_graphs = k2.cat([num_graphs, den_graph])

    # NOTE: The a_to_b_map in k2.intersect_dense must be sorted
    # so the following reorders num_den_graphs.

    # [0, 1, 2, ... ]
    num_graphs_indexes = torch.arange(num_fsas, dtype=torch.int32)
    # [num_fsas, num_fsas, num_fsas, ... ]
    den_graph_indexes = torch.tensor([num_fsas] * num_fsas, dtype=torch.int32)
    # [0, num_fsas, 1, num_fsas, 2, num_fsas, ... ]
    num_den_graphs_indexes = torch.stack(
        [num_graphs_indexes, den_graph_indexes]).t().reshape(-1).to(device)
    num_den_reordered_graphs = k2.index_fsa(num_den_graphs, num_den_graphs_indexes)

    # [[0, 1, 2, ...]]
    a_to_b_map = torch.arange(num_fsas, dtype=torch.int32).reshape(1, -1)
    # [[0, 1, 2, ...]] -> [0, 0, 1, 1, 2, 2, ... ]
    a_to_b_map = a_to_b_map.repeat(2, 1).t().reshape(-1).to(device)

    num_den_lats = k2.intersect_dense(
        a_fsas=num_den_reordered_graphs,
        b_fsas=dense_fsa_vec,
        output_beam=self.intersect_conf.output_beam,
        a_to_b_map=a_to_b_map,
        seqframe_idx_name="seqframe_idx" if return_lats else None,
    )

    num_den_tot_scores = num_den_lats.get_tot_scores(log_semiring=True,
                                                     use_double_scores=True)

    num_tot_scores = num_den_tot_scores[::2]
    den_tot_scores = num_den_tot_scores[1::2]

    if return_lats:
        lat_slice = torch.arange(num_fsas, dtype=torch.int32).to(device) * 2
        return (
            num_tot_scores,
            den_tot_scores,
            k2.index_fsa(num_den_lats, lat_slice),
            k2.index_fsa(num_den_lats, lat_slice + 1),
        )
    else:
        return num_tot_scores, den_tot_scores, None, None
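# A minimal sketch (pure torch, no k2 required) of the index interleaving used
# above. Assumption: num_fsas = 3 is hypothetical example data. Each utterance
# keeps its own numerator graph index [0, 1, 2], while all utterances share the
# single denominator graph stored at index num_fsas, giving [0, 3, 1, 3, 2, 3].
import torch

num_fsas = 3
num_idx = torch.arange(num_fsas, dtype=torch.int32)             # [0, 1, 2]
den_idx = torch.full((num_fsas,), num_fsas, dtype=torch.int32)  # [3, 3, 3]
# stack -> shape (2, 3); transpose -> rows [0, 3], [1, 3], [2, 3]; flatten
interleaved = torch.stack([num_idx, den_idx]).t().reshape(-1)
print(interleaved)  # tensor([0, 3, 1, 3, 2, 3], dtype=torch.int32)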
def compile(self,
            texts: Iterable[str],
            P: k2.Fsa,
            replicate_den: bool = True) -> Tuple[k2.Fsa, k2.Fsa]:
    '''Create numerator and denominator graphs from transcripts
    and the bigram phone LM.

    Args:
      texts:
        A list of transcripts. Within a transcript, words are
        separated by spaces.
      P:
        The bigram phone LM created by :func:`create_bigram_phone_lm`.
      replicate_den:
        If True, the returned den_graph is replicated to match the number
        of FSAs in the returned num_graph; if False, the returned den_graph
        contains only a single FSA.
    Returns:
      A tuple (num_graph, den_graph), where

        - `num_graph` is the numerator graph. It is an FsaVec with
          shape `(len(texts), None, None)`.

        - `den_graph` is the denominator graph. It is an FsaVec with the
          same shape as `num_graph` if replicate_den is True; otherwise,
          it is an FsaVec containing only a single FSA.
    '''
    assert P.device == self.device
    P_with_self_loops = k2.add_epsilon_self_loops(P)

    ctc_topo_P = k2.intersect(self.ctc_topo_inv,
                              P_with_self_loops,
                              treat_epsilons_specially=False).invert()
    ctc_topo_P = k2.arc_sort(ctc_topo_P)

    num_graphs = self.build_num_graphs(texts)
    num_graphs_with_self_loops = k2.remove_epsilon_and_add_self_loops(
        num_graphs)
    num_graphs_with_self_loops = k2.arc_sort(num_graphs_with_self_loops)

    num = k2.compose(ctc_topo_P,
                     num_graphs_with_self_loops,
                     treat_epsilons_specially=False)
    num = k2.arc_sort(num)

    ctc_topo_P_vec = k2.create_fsa_vec([ctc_topo_P.detach()])
    if replicate_den:
        indexes = torch.zeros(len(texts),
                              dtype=torch.int32,
                              device=self.device)
        den = k2.index_fsa(ctc_topo_P_vec, indexes)
    else:
        den = ctc_topo_P_vec

    return num, den
def compile(self, texts: Iterable[str],
            P: k2.Fsa) -> Tuple[k2.Fsa, k2.Fsa, k2.Fsa]:
    '''Create numerator and denominator graphs from transcripts
    and the bigram phone LM.

    Args:
      texts:
        A list of transcripts. Within a transcript, words are
        separated by spaces.
      P:
        The bigram phone LM created by :func:`create_bigram_phone_lm`.
    Returns:
      A tuple (num_graph, den_graph, decoding_graph), where

        - `num_graph` is the numerator graph. It is an FsaVec with
          shape `(len(texts), None, None)`.
          It is the result of compose(ctc_topo, P, L, transcript).

        - `den_graph` is the denominator graph. It is an FsaVec with the
          same shape as `num_graph`.
          It is the result of compose(ctc_topo, P).

        - `decoding_graph` is the result of compose(ctc_topo, L_disambig, G).
          Note that it is a single Fsa, not an FsaVec.
    '''
    assert P.device == self.device
    P_with_self_loops = k2.add_epsilon_self_loops(P)

    ctc_topo_P = k2.intersect(self.ctc_topo_inv,
                              P_with_self_loops,
                              treat_epsilons_specially=False).invert()
    ctc_topo_P = k2.arc_sort(ctc_topo_P)

    num_graphs = self.build_num_graphs(texts)
    num_graphs_with_self_loops = k2.remove_epsilon_and_add_self_loops(
        num_graphs)
    num_graphs_with_self_loops = k2.arc_sort(num_graphs_with_self_loops)

    num = k2.compose(ctc_topo_P,
                     num_graphs_with_self_loops,
                     treat_epsilons_specially=False,
                     inner_labels='phones')
    num = k2.arc_sort(num)

    ctc_topo_P_vec = k2.create_fsa_vec([ctc_topo_P.detach()])
    indexes = torch.zeros(len(texts),
                          dtype=torch.int32,
                          device=self.device)
    den = k2.index_fsa(ctc_topo_P_vec, indexes)

    return num, den, self.decoding_graph
def _intersect_device(
    a_fsas: k2.Fsa,
    b_fsas: k2.Fsa,
    b_to_a_map: torch.Tensor,
    sorted_match_a: bool,
    batch_size: int = 500,
):
    """Wrap k2.intersect_device

    This is a wrapper of k2.intersect_device; its purpose is to split
    b_fsas into several batches and process each batch separately to
    avoid CUDA OOM errors. The arguments and return value of this
    function are the same as those of k2.intersect_device.

    NOTE: You can decrease batch_size in case of a CUDA out-of-memory error.
    """
    num_fsas = b_fsas.shape[0]
    if num_fsas <= batch_size:
        return k2.intersect_device(
            a_fsas, b_fsas, b_to_a_map=b_to_a_map, sorted_match_a=sorted_match_a
        )

    num_batches = int(math.ceil(float(num_fsas) / batch_size))
    splits = []
    for i in range(num_batches):
        start = i * batch_size
        end = min(start + batch_size, num_fsas)
        splits.append((start, end))

    ans = []
    for start, end in splits:
        indexes = torch.arange(start, end).to(b_to_a_map)
        fsas = k2.index_fsa(b_fsas, indexes)
        b_to_a = k2.index_select(b_to_a_map, indexes)
        path_lats = k2.intersect_device(
            a_fsas, fsas, b_to_a_map=b_to_a, sorted_match_a=sorted_match_a
        )
        ans.append(path_lats)

    return k2.cat(ans)
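# A small usage sketch of the batching logic above (hypothetical sizes only;
# no k2 calls are made). With num_fsas = 1234 and batch_size = 500, b_fsas
# would be processed in three slices, the last one shorter than the rest.
import math

num_fsas, batch_size = 1234, 500
num_batches = int(math.ceil(float(num_fsas) / batch_size))
splits = [(i * batch_size, min((i + 1) * batch_size, num_fsas))
          for i in range(num_batches)]
print(splits)  # [(0, 500), (500, 1000), (1000, 1234)]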
def top_k(self, k: int) -> 'Nbest':
    '''Get a subset of paths in the Nbest. The resulting Nbest is regular
    in that each sequence (i.e., utterance) has the same number of
    paths (k).

    We select the top-k paths according to the total_scores of each path.
    If an utterance has fewer than k paths, then its last path, after
    sorting by tot_scores in descending order, is repeated so that each
    utterance has exactly k paths.

    Args:
      k:
        Number of paths in each utterance.
    Returns:
      Return a new Nbest with a regular shape.
    '''
    ragged_scores = self.total_scores()

    # indexes contains idx01's for self.shape
    # ragged_scores.values()[indexes] is sorted
    indexes = k2.ragged.sort_sublist(ragged_scores,
                                     descending=True,
                                     need_new2old_indexes=True)

    ragged_indexes = k2.RaggedInt(self.shape, indexes)

    padded_indexes = k2.ragged.pad(ragged_indexes,
                                   mode='replicate',
                                   value=-1)
    assert torch.ge(padded_indexes, 0).all(), \
        'Some utterances contain empty ' \
        f'n-best: {self.shape.row_splits(1)}'

    # Select the idx01's of top-k paths of each utterance
    top_k_indexes = padded_indexes[:, :k].flatten().contiguous()

    top_k_fsas = k2.index_fsa(self.fsa, top_k_indexes)

    top_k_shape = k2.ragged.regular_ragged_shape(dim0=self.shape.dim0(),
                                                 dim1=k)
    return Nbest(top_k_fsas, top_k_shape)
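# A pure-torch sketch of the sort-then-replicate-pad idea used in top_k above
# (hypothetical data, no k2 needed): per-utterance path scores are sorted in
# descending order, and the index of the worst path is repeated when an
# utterance has fewer than k paths, so every utterance yields exactly k idx01's.
import torch

k = 3
# Flat scores of 5 paths over two utterances; the second has only two paths.
per_utt_scores = [torch.tensor([0.5, 2.0, 1.0]), torch.tensor([4.0, 3.0])]
offsets = [0, 3]  # analogous to row_splits: start of each utterance's paths
top_k_indexes = []
for offset, scores in zip(offsets, per_utt_scores):
    order = torch.argsort(scores, descending=True) + offset
    if order.numel() < k:  # replicate the last (lowest-scoring) path
        order = torch.cat([order, order[-1:].repeat(k - order.numel())])
    top_k_indexes.append(order[:k])
print(torch.cat(top_k_indexes))  # tensor([1, 2, 0, 3, 4, 4])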
def test(self):
    s0 = '''
        0 1 1 0.1
        0 2 2 0.2
        1 2 3 0.3
        2 3 -1 0.4
        3
    '''
    s1 = '''
        0 1 -1 0.5
        1
    '''
    s2 = '''
        0 2 1 0.6
        0 1 2 0.7
        1 3 -1 0.8
        2 1 3 0.9
        3
    '''
    for device in self.devices:
        fsa0 = k2.Fsa.from_str(s0).to(device).requires_grad_(True)
        fsa1 = k2.Fsa.from_str(s1).to(device).requires_grad_(True)
        fsa2 = k2.Fsa.from_str(s2).to(device).requires_grad_(True)

        fsa_vec = k2.create_fsa_vec([fsa0, fsa1, fsa2])
        new_fsa21 = k2.index_fsa(
            fsa_vec, torch.tensor([2, 1], dtype=torch.int32, device=device))
        assert new_fsa21.shape == (2, None, None)
        assert torch.all(
            torch.eq(
                new_fsa21.arcs.values()[:, :3],
                torch.tensor([
                    # fsa 2
                    [0, 2, 1],
                    [0, 1, 2],
                    [1, 3, -1],
                    [2, 1, 3],
                    # fsa 1
                    [0, 1, -1]
                ]).to(torch.int32).to(device)))

        scale = torch.arange(new_fsa21.scores.numel(), device=device)
        (new_fsa21.scores * scale).sum().backward()
        assert torch.allclose(fsa0.scores.grad,
                              torch.tensor([0., 0, 0, 0], device=device))
        assert torch.allclose(fsa1.scores.grad,
                              torch.tensor([4.], device=device))
        assert torch.allclose(
            fsa2.scores.grad,
            torch.tensor([0., 1., 2., 3.], device=device))

        # now select only a single FSA
        fsa0.scores.grad = None
        fsa1.scores.grad = None
        fsa2.scores.grad = None

        new_fsa0 = k2.index_fsa(
            fsa_vec, torch.tensor([0], dtype=torch.int32, device=device))
        assert new_fsa0.shape == (1, None, None)

        scale = torch.arange(new_fsa0.scores.numel(), device=device)
        (new_fsa0.scores * scale).sum().backward()
        assert torch.allclose(
            fsa0.scores.grad,
            torch.tensor([0., 1., 2., 3.], device=device))
        assert torch.allclose(fsa1.scores.grad,
                              torch.tensor([0.], device=device))
        assert torch.allclose(
            fsa2.scores.grad,
            torch.tensor([0., 0., 0., 0.], device=device))
def decode(
    self,
    log_probs: torch.Tensor,
    log_probs_length: torch.Tensor,
    return_lattices: bool = False,
    return_ilabels: bool = False,
    output_aligned: bool = True,
) -> Union['k2.Fsa', Tuple[List[torch.Tensor], List[torch.Tensor]]]:
    if self.decoding_graph is None:
        self.decoding_graph = self.base_graph

    if self.blank != 0:
        # rearrange log_probs to put blank at the first place
        # and shift targets to emulate blank = 0
        log_probs, _ = make_blank_first(self.blank, log_probs, None)
    supervisions, order = create_supervision(log_probs_length)
    if self.decoding_graph.shape[0] > 1:
        self.decoding_graph = k2.index_fsa(self.decoding_graph,
                                           order).to(device=log_probs.device)

    if log_probs.device != self.device:
        self.to(log_probs.device)
    dense_fsa_vec = (
        prep_padded_densefsavec(log_probs, supervisions)
        if self.pad_fsavec
        else k2.DenseFsaVec(log_probs, supervisions)
    )

    if self.intersect_pruned:
        lats = k2.intersect_dense_pruned(
            a_fsas=self.decoding_graph,
            b_fsas=dense_fsa_vec,
            search_beam=self.intersect_conf.search_beam,
            output_beam=self.intersect_conf.output_beam,
            min_active_states=self.intersect_conf.min_active_states,
            max_active_states=self.intersect_conf.max_active_states,
        )
    else:
        indices = torch.zeros(dense_fsa_vec.dim0(),
                              dtype=torch.int32,
                              device=self.device)
        dec_graphs = (
            k2.index_fsa(self.decoding_graph, indices)
            if self.decoding_graph.shape[0] == 1
            else self.decoding_graph
        )
        lats = k2.intersect_dense(dec_graphs, dense_fsa_vec,
                                  self.intersect_conf.output_beam)
    if self.pad_fsavec:
        shift_labels_inpl([lats], -1)
    self.decoding_graph = None

    if return_lattices:
        lats = k2.index_fsa(lats,
                            invert_permutation(order).to(device=log_probs.device))
        if self.blank != 0:
            # change only ilabels
            # suppose self.blank == self.num_classes - 1
            lats.labels = torch.where(lats.labels == 0, self.blank,
                                      lats.labels - 1)
        return lats
    else:
        shortest_path_fsas = k2.index_fsa(
            k2.shortest_path(lats, True),
            invert_permutation(order).to(device=log_probs.device),
        )
        shortest_paths = []
        probs = []
        # direct iterating does not work as expected
        for i in range(shortest_path_fsas.shape[0]):
            shortest_path_fsa = shortest_path_fsas[i]
            labels = (
                shortest_path_fsa.labels[:-1].to(dtype=torch.long)
                if return_ilabels
                else shortest_path_fsa.aux_labels[:-1].to(dtype=torch.long)
            )
            if self.blank != 0:
                # suppose self.blank == self.num_classes - 1
                labels = torch.where(labels == 0, self.blank, labels - 1)
            if not return_ilabels and not output_aligned:
                labels = labels[labels != self.blank]
            shortest_paths.append(labels[::2] if self.pad_fsavec else labels)
            probs.append(get_arc_weights(shortest_path_fsa)[:-1].to(
                device=log_probs.device).exp())
        return shortest_paths, probs
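# `invert_permutation` above is assumed to be a helper from the surrounding
# codebase; a minimal torch sketch of the same idea: if order[i] names the
# source item that was moved to position i, the inverse maps each source item
# back to its position, so indexing lattices with it restores input order.
import torch

def invert_permutation_sketch(order: torch.Tensor) -> torch.Tensor:
    # scatter positions 0..n-1 back to the locations named by `order`
    inv = torch.empty_like(order)
    inv[order.to(dtype=torch.long)] = torch.arange(
        order.numel(), dtype=order.dtype, device=order.device)
    return inv

print(invert_permutation_sketch(torch.tensor([2, 0, 1])))  # tensor([1, 2, 0])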