def test(self):
    s0 = '''
        0 1 1 0.1
        0 2 2 0.2
        1 2 3 0.3
        2 3 -1 0.4
        3
    '''
    s1 = '''
        0 1 -1 0.5
        1
    '''
    s2 = '''
        0 2 1 0.6
        0 1 2 0.7
        1 3 -1 0.8
        2 1 3 0.9
        3
    '''
    fsa0 = k2.Fsa.from_str(s0).requires_grad_(True)
    fsa1 = k2.Fsa.from_str(s1).requires_grad_(True)
    fsa2 = k2.Fsa.from_str(s2).requires_grad_(True)

    fsa_vec = k2.create_fsa_vec([fsa0, fsa1, fsa2])

    new_fsa21 = k2.index(fsa_vec, torch.tensor([2, 1], dtype=torch.int32))
    assert new_fsa21.shape == (2, None, None)
    assert torch.allclose(
        new_fsa21.arcs.values()[:, :3],
        torch.tensor([
            # fsa 2
            [0, 2, 1],
            [0, 1, 2],
            [1, 3, -1],
            [2, 1, 3],
            # fsa 1
            [0, 1, -1]
        ]).to(torch.int32))

    scale = torch.arange(new_fsa21.scores.numel())
    (new_fsa21.scores * scale).sum().backward()
    assert torch.allclose(fsa0.scores.grad, torch.tensor([0., 0, 0, 0]))
    assert torch.allclose(fsa1.scores.grad, torch.tensor([4.]))
    assert torch.allclose(fsa2.scores.grad, torch.tensor([0., 1., 2., 3.]))

    # now select only a single FSA
    fsa0.scores.grad = None
    fsa1.scores.grad = None
    fsa2.scores.grad = None

    new_fsa0 = k2.index(fsa_vec, torch.tensor([0], dtype=torch.int32))
    assert new_fsa0.shape == (1, None, None)

    scale = torch.arange(new_fsa0.scores.numel())
    (new_fsa0.scores * scale).sum().backward()
    assert torch.allclose(fsa0.scores.grad, torch.tensor([0., 1., 2., 3.]))
    assert torch.allclose(fsa1.scores.grad, torch.tensor([0.]))
    assert torch.allclose(fsa2.scores.grad, torch.tensor([0., 0., 0., 0.]))
def test(self):
    devices = [torch.device('cpu')]
    if torch.cuda.is_available():
        devices.append(torch.device('cuda', 0))
    for device in devices:
        src_row_splits = torch.tensor([0, 2, 3, 3, 6],
                                      dtype=torch.int32,
                                      device=device)
        src_shape = k2.ragged.create_ragged_shape2(src_row_splits, None, 6)
        src_values = torch.tensor([1, 2, 3, 4, 5, 6],
                                  dtype=torch.int32,
                                  device=device)
        src = k2.RaggedInt(src_shape, src_values)

        # index with ragged int
        index_row_splits = torch.tensor([0, 2, 2, 3, 7],
                                        dtype=torch.int32,
                                        device=device)
        index_shape = k2.ragged.create_ragged_shape2(
            index_row_splits, None, 7)
        index_values = torch.tensor([0, 3, 2, 1, 2, 1, 0],
                                    dtype=torch.int32,
                                    device=device)
        ragged_index = k2.RaggedInt(index_shape, index_values)
        ans = k2.index(src, ragged_index)
        expected_row_splits = torch.tensor([0, 5, 5, 5, 9],
                                           dtype=torch.int32,
                                           device=device)
        self.assertTrue(
            torch.allclose(ans.row_splits(1), expected_row_splits))
        expected_values = torch.tensor([1, 2, 4, 5, 6, 3, 3, 1, 2],
                                       dtype=torch.int32,
                                       device=device)
        self.assertTrue(torch.allclose(ans.values(), expected_values))

        # index with tensor
        tensor_index = torch.tensor([0, 3, 2, 1, 2, 1],
                                    dtype=torch.int32,
                                    device=device)
        ans = k2.index(src, tensor_index)
        expected_row_splits = torch.tensor([0, 2, 5, 5, 6, 6, 7],
                                           dtype=torch.int32,
                                           device=device)
        self.assertTrue(
            torch.allclose(ans.row_splits(1), expected_row_splits))
        expected_values = torch.tensor([1, 2, 4, 5, 6, 3, 3],
                                       dtype=torch.int32,
                                       device=device)
        self.assertTrue(torch.allclose(ans.values(), expected_values))
def fsa_from_unary_function_ragged(src: Fsa, dest_arcs: _k2.RaggedArc,
                                   arc_map: _k2.RaggedInt) -> Fsa:
    '''Create an Fsa object, including autograd logic and propagating
    properties from the source FSA.

    This is intended to be called from unary functions on FSAs where the
    arc_map is an instance of _k2.RaggedInt.

    Args:
      src:
        The source Fsa, i.e. the arg to the unary function.
      dest_arcs:
        The raw output of the unary function, as output by whatever C++
        algorithm we used.
      arc_map:
        A map from arcs in `dest_arcs` to the corresponding arc-index in
        `src`, or -1 if the arc had no source arc
        (e.g. :func:`remove_epsilon`).
    Returns:
      Returns the resulting Fsa, with properties propagated appropriately,
      and autograd handled.
    '''
    dest = Fsa(dest_arcs)
    for name, value in src.named_tensor_attr(include_scores=False):
        setattr(dest, name, k2.index(value, arc_map))

    for name, value in src.named_non_tensor_attr():
        setattr(dest, name, value)

    k2.autograd_utils.phantom_index_and_sum_scores(dest, src.scores, arc_map)

    return dest
def forward(self, log_probs: torch.Tensor, targets: torch.Tensor,
            input_lengths: torch.Tensor,
            target_lengths: torch.Tensor) -> torch.Tensor:
    # now log_probs is [N, T, C], i.e. batch_size x seq_length x alphabet_size
    log_probs = log_probs.permute(1, 0, 2).cpu()
    supervision_segments = torch.stack(
        (torch.tensor(range(input_lengths.shape[0])),
         torch.zeros(input_lengths.shape[0]), input_lengths),
        1).to(torch.int32)
    indices = torch.argsort(supervision_segments[:, 2], descending=True)
    supervision_segments = supervision_segments[indices]

    dense_fsa_vec = k2.DenseFsaVec(log_probs, supervision_segments)
    decoding_graph = self.graph_compiler.compile(targets.cpu(),
                                                 target_lengths)
    decoding_graph = k2.index(decoding_graph,
                              indices.to(torch.int32)).to(log_probs.device)

    target_graph = k2.intersect_dense(decoding_graph, dense_fsa_vec, 10.0)
    tot_scores = k2.get_tot_scores(target_graph,
                                   log_semiring=True,
                                   use_double_scores=True)
    (tot_score, tot_frames,
     all_frames) = get_tot_objf_and_num_frames(tot_scores,
                                               supervision_segments[:, 2])
    return -tot_score
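The supervision-segment bookkeeping above is easy to get wrong, so here is a small self-contained sketch (plain PyTorch, with made-up lengths) of how the (sequence_index, start_frame, num_frames) rows are built and then reordered by decreasing length, as the surrounding code assumes k2.DenseFsaVec requires:

import torch

# Hypothetical per-sequence frame counts for a batch of 3.
input_lengths = torch.tensor([5, 9, 7])
n = input_lengths.shape[0]

# Each row is (sequence_index, start_frame, num_frames); all int32 so the
# three columns can be stacked without dtype mismatches.
supervision_segments = torch.stack(
    (torch.arange(n, dtype=torch.int32),
     torch.zeros(n, dtype=torch.int32),
     input_lengths.to(torch.int32)), dim=1)

# Sort by decreasing num_frames; `indices` is what the code above reuses
# to reorder the decoding graphs so they line up with the dense FSAs.
indices = torch.argsort(supervision_segments[:, 2], descending=True)
print(supervision_segments[indices])
# tensor([[1, 0, 9],
#         [2, 0, 7],
#         [0, 0, 5]], dtype=torch.int32)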
def forward(ctx, fsas: Fsa, out_fsa: List[Fsa],
            unused_fsas_scores: torch.Tensor) -> torch.Tensor:
    '''Compute the union of all FSAs in an FsaVec.

    Args:
      fsas:
        The input FsaVec. Caution: We require that each fsa in the FsaVec
        is non-empty (i.e., with at least two states).
      out_fsa:
        A list containing one entry. Since this function can only return
        values of type `torch.Tensor`, we return the union result in the
        list.
      unused_fsas_scores:
        It is the same as `fsas.scores`, whose sole purpose is for
        autograd. It is not used in this function.
    '''
    need_arc_map = True
    ragged_arc, arc_map = _k2.union(fsas.arcs, need_arc_map)
    out_fsa[0] = Fsa(ragged_arc)

    for name, value in fsas.named_tensor_attr(include_scores=False):
        value = k2.index(value, arc_map)
        setattr(out_fsa[0], name, value)

    for name, value in fsas.named_non_tensor_attr():
        setattr(out_fsa[0], name, value)

    ctx.arc_map = arc_map
    ctx.save_for_backward(unused_fsas_scores)

    return out_fsa[0].scores  # the return value will be discarded
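For context, a minimal sketch of the user-level call that ultimately reaches this autograd function, using the public `k2.union` wrapper and the text format used elsewhere in this file:

import torch
import k2

# Two tiny FSAs in k2's text format: src_state dst_state label score.
fsa0 = k2.Fsa.from_str('''
0 1 1 0.1
1 2 -1 0.2
2
''').requires_grad_(True)

fsa1 = k2.Fsa.from_str('''
0 1 -1 0.3
1
''').requires_grad_(True)

fsa_vec = k2.create_fsa_vec([fsa0, fsa1])
merged = k2.union(fsa_vec)  # a single Fsa; a new start state fans out to both

# Gradients flow back to fsa0.scores / fsa1.scores through the arc_map.
merged.scores.sum().backward()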
def expand_ragged_attributes(
        fsas: Fsa,
        ret_arc_map: bool = False
) -> Union[Fsa, Tuple[Fsa, torch.Tensor]]:  # noqa
    '''Turn ragged labels attached to this FSA into linear (Tensor) labels,
    expanding arcs into sequences of arcs as necessary to achieve this.
    Supports autograd. If `fsas` had no ragged attributes, returns `fsas`
    itself.

    Args:
      ret_arc_map:
        If true, returns a pair (new_fsas, arc_map) where `arc_map` is a
        tensor of int32 that maps from arcs in the result to arcs in
        `fsas`, with -1's for newly created arcs. If false, returns just
        new_fsas.
    '''
    ragged_attribute_tensors = []
    ragged_attribute_names = []
    for name, value in fsas.named_tensor_attr(include_scores=False):
        if isinstance(value, k2.RaggedInt):
            ragged_attribute_tensors.append(value)
            ragged_attribute_names.append(name)

    if len(ragged_attribute_tensors) == 0:
        if ret_arc_map:
            arc_map = torch.arange(fsas.num_arcs,
                                   dtype=torch.int32,
                                   device=fsas.device)
            return (fsas, arc_map)
        else:
            return fsas

    (dest_arcs, dest_labels,
     arc_map) = _k2.expand_arcs(fsas.arcs, ragged_attribute_tensors)

    # The rest of this function is a modified version of
    # `fsa_from_unary_function_tensor()`.
    dest = Fsa(dest_arcs)

    # Handle the non-ragged attributes
    for name, value in fsas.named_tensor_attr(include_scores=False):
        if not isinstance(value, k2.RaggedInt):
            setattr(dest, name, k2.index(value, arc_map))

    # Handle the attributes that were ragged but are now linear
    for name, value in zip(ragged_attribute_names, dest_labels):
        setattr(dest, name, value)

    # Copy non-tensor attributes
    for name, value in fsas.named_non_tensor_attr():
        setattr(dest, name, value)

    # make sure autograd works on the scores
    k2.autograd_utils.phantom_index_select_scores(dest, fsas.scores, arc_map)

    if ret_arc_map:
        return dest, arc_map
    else:
        return dest
def ctc_graph(symbols: Union[List[List[int]], k2.RaggedInt],
              modified: bool = False,
              device: Optional[Union[torch.device, str]] = None) -> Fsa:
    '''Construct CTC graphs from symbols.

    Note:
      The scores of arcs in the returned FSA are all 0.

    Args:
      symbols:
        It can be one of the following types:

            - A list of lists of integers, e.g., `[ [1, 2], [1, 2, 3] ]`
            - An instance of :class:`k2.RaggedInt`.
              Must have `num_axes() == 2`.
      modified:
        Option to specify the type of CTC topology: "standard" or
        "modified", where the "standard" one makes the blank mandatory
        between a pair of identical symbols. Default False
        (i.e. the standard topology).
      device:
        Optional. It can be either a string (e.g., 'cpu', 'cuda:0') or a
        torch.device. If it is None, then the returned FSA is on CPU. It
        has to be None if `symbols` is an instance of
        :class:`k2.RaggedInt`; in that case the returned FSA will be on
        the same device as `symbols`.

    Returns:
      An FsaVec containing the constructed CTC graphs. Its `Dim0()` equals
      `len(symbols)` if `symbols` is a list of lists, or `symbols.dim0()`
      if it is a :class:`k2.RaggedInt`.
    '''
    if device is not None:
        device = torch.device(device)
        if device.type == 'cpu':
            gpu_id = -1
        else:
            assert device.type == 'cuda'
            gpu_id = getattr(device, 'index', 0)
    else:
        gpu_id = -1

    symbol_values = None
    if isinstance(symbols, k2.RaggedInt):
        assert device is None
        assert symbols.num_axes() == 2
        symbol_values = symbols.values()
    else:
        symbol_values = torch.tensor(
            [it for symbol in symbols for it in symbol],
            dtype=torch.int32,
            device=device)

    need_arc_map = True
    ragged_arc, arc_map = _k2.ctc_graph(symbols, gpu_id, modified,
                                        need_arc_map)
    aux_labels = k2.index(symbol_values, arc_map)
    fsa = Fsa(ragged_arc, aux_labels=aux_labels)
    return fsa
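A minimal usage sketch; the expected `Dim0()` follows from the docstring above:

import k2

# One CTC graph per label sequence; the returned FsaVec has Dim0() == 2.
ctc_graphs = k2.ctc_graph([[1, 2], [1, 2, 3]], modified=False, device='cpu')
assert ctc_graphs.shape[0] == 2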
def generate_nbest_list(lats: Fsa, num_paths: int) -> Nbest:
    '''Generate an n-best list from a lattice.

    Args:
      lats:
        The decoding lattice from the first pass after LM rescoring.
        lats is an FsaVec. It can be the return value of
        :func:`whole_lattice_rescoring`.
      num_paths:
        Size of n for n-best list. CAUTION: After removing paths
        that represent the same token sequences, the number of paths
        in different sequences may not be equal.
    Return:
      Return an Nbest object. Note the returned FSAs don't have epsilon
      self-loops.
    '''
    assert len(lats.shape) == 3

    # CAUTION: We use `phones` instead of `tokens` here because
    # :func:`compile_HLG` uses `phones`
    #
    # Note: compile_HLG is from k2-fsa/snowfall
    assert hasattr(lats, 'phones')

    assert not hasattr(lats, 'tokens')
    lats.tokens = lats.phones
    # we use tokens instead of phones in the following code

    # First, extract `num_paths` paths for each sequence.
    # paths is a k2.RaggedInt with axes [seq][path][arc_pos]
    paths = k2.random_paths(lats, num_paths=num_paths, use_double_scores=True)

    # token_seqs is a k2.RaggedInt sharing the same shape as `paths`
    # but it contains token IDs. Note that it also contains 0s and -1s.
    # The last entry in each sublist is -1.
    # Its axes are [seq][path][token_id]
    token_seqs = k2.index(lats.tokens, paths)

    # Remove epsilons (0s) and -1 from token_seqs
    token_seqs = k2.ragged.remove_values_leq(token_seqs, 0)

    # unique_token_seqs is still a k2.RaggedInt with axes
    # [seq][path][token_id], but the number of paths in each sequence
    # may be different.
    unique_token_seqs, _, _ = k2.ragged.unique_sequences(
        token_seqs, need_num_repeats=False, need_new2old_indexes=False)

    seq_to_path_shape = k2.ragged.get_layer(unique_token_seqs.shape(), 0)

    # Remove the seq axis.
    # Now unique_token_seqs has only two axes [path][token_id]
    unique_token_seqs = k2.ragged.remove_axis(unique_token_seqs, 0)

    token_fsas = k2.linear_fsa(unique_token_seqs)

    return Nbest(fsa=token_fsas, shape=seq_to_path_shape)
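The `remove_values_leq` step above drops both epsilons (0s) and the trailing -1s in one pass. A small illustration, assuming the pre-v1.0 `k2.ragged` API used throughout this file:

import k2

token_seqs = k2.RaggedInt('[ [0 2 0 3 -1] [0 -1] ]')
print(k2.ragged.remove_values_leq(token_seqs, 0))
# expected: [ [ 2 3 ] [ ] ]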
def _intersect_device(a_fsas: k2.Fsa, b_fsas: k2.Fsa,
                      b_to_a_map: torch.Tensor,
                      sorted_match_a: bool):
    '''This is a wrapper of k2.intersect_device and its purpose is to split
    b_fsas into several batches and process each batch separately to avoid
    CUDA OOM error.

    The arguments and return value of this function are the same as
    k2.intersect_device.
    '''
    # NOTE: You can decrease batch_size in case of CUDA out of memory error.
    batch_size = 500
    num_fsas = b_fsas.shape[0]
    if num_fsas <= batch_size:
        return k2.intersect_device(a_fsas,
                                   b_fsas,
                                   b_to_a_map=b_to_a_map,
                                   sorted_match_a=sorted_match_a)

    num_batches = int(math.ceil(float(num_fsas) / batch_size))
    splits = []
    for i in range(num_batches):
        start = i * batch_size
        end = min(start + batch_size, num_fsas)
        splits.append((start, end))

    ans = []
    for start, end in splits:
        indexes = torch.arange(start, end).to(b_to_a_map)
        fsas = k2.index(b_fsas, indexes)
        b_to_a = k2.index(b_to_a_map, indexes)
        path_lats = k2.intersect_device(a_fsas,
                                        fsas,
                                        b_to_a_map=b_to_a,
                                        sorted_match_a=sorted_match_a)
        ans.append(path_lats)

    return k2.cat(ans)
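To make the batching concrete, a quick check of the split computation above with hypothetical sizes:

import math

batch_size, num_fsas = 500, 1203  # made-up example
num_batches = int(math.ceil(float(num_fsas) / batch_size))
splits = [(i * batch_size, min((i + 1) * batch_size, num_fsas))
          for i in range(num_batches)]
print(splits)  # [(0, 500), (500, 1000), (1000, 1203)]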
def test_sort_sublist_descending(self):
    for device in self.devices:
        src = k2.RaggedInt('[ [3 2] [] [1 5 2]]').to(device)
        src_clone = src.clone()
        new2old = k2.ragged.sort_sublist(src,
                                         descending=True,
                                         need_new2old_indexes=True)
        sorted_src = k2.RaggedInt('[[3 2] [] [5 2 1]]')
        expected_new2old = torch.tensor([0, 1, 3, 4, 2],
                                        device=device,
                                        dtype=torch.int32)
        assert str(src) == str(sorted_src)
        assert torch.all(torch.eq(new2old, expected_new2old))

        expected_sorted = k2.index(src_clone.values(), new2old)
        sorted_values = src.values()  # renamed to avoid shadowing sorted()
        assert torch.all(torch.eq(expected_sorted, sorted_values))
def test(self):
    for device in self.devices:
        src = torch.tensor([1, 2, 3, 4, 5, 6, 7],
                           dtype=torch.int32,
                           device=device)

        index_row_splits = torch.tensor([0, 2, 2, 3, 7],
                                        dtype=torch.int32,
                                        device=device)
        index_shape = k2.ragged.create_ragged_shape2(
            index_row_splits, None, 7)
        index_values = torch.tensor([0, 3, 2, 3, 5, 1, 3],
                                    dtype=torch.int32,
                                    device=device)
        ragged_index = k2.RaggedInt(index_shape, index_values)

        ans = k2.index(src, ragged_index)
        self.assertTrue(
            torch.allclose(ans.row_splits(1), index_row_splits))
        expected_values = torch.tensor([1, 4, 3, 4, 6, 2, 4],
                                       dtype=torch.int32,
                                       device=device)
        self.assertTrue(torch.allclose(ans.values(), expected_values))
def _compute_mmi_loss_exact_optimized(
        nnet_output: torch.Tensor,
        texts: List[str],
        supervision_segments: torch.Tensor,
        graph_compiler: MmiTrainingGraphCompiler,
        P: k2.Fsa,
        den_scale: float = 1.0
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    '''
    The function name contains `exact`, which means it uses a version of
    intersection without pruning.

    `optimized` in the function name means this function is optimized in
    that it calls k2.intersect_dense only once.

    Note:
      It is faster at the cost of using more memory.

    Args:
      nnet_output:
        A 3-D tensor of shape [N, T, C]
      texts:
        The transcript. Each element consists of space(s) separated words.
      supervision_segments:
        A 2-D tensor that will be passed to :func:`k2.DenseFsaVec`.
      graph_compiler:
        Used to build num_graphs and den_graphs
      P:
        Represents a bigram Fsa.
      den_scale:
        The scale applied to the denominator tot_scores.
    '''
    num_graphs, den_graphs = graph_compiler.compile(texts,
                                                    P,
                                                    replicate_den=False)

    dense_fsa_vec = k2.DenseFsaVec(nnet_output, supervision_segments)

    device = num_graphs.device

    num_fsas = num_graphs.shape[0]
    assert dense_fsa_vec.dim0() == num_fsas

    assert den_graphs.shape[0] == 1

    # the aux_labels of num_graphs is k2.RaggedInt
    # but it is torch.Tensor for den_graphs.
    #
    # The following converts den_graphs.aux_labels
    # from torch.Tensor to k2.RaggedInt so that
    # we can use k2.append() later
    den_graphs.convert_attr_to_ragged_(name='aux_labels')

    # The motivation to concatenate num_graphs and den_graphs
    # is to reduce the number of calls to k2.intersect_dense.
    num_den_graphs = k2.cat([num_graphs, den_graphs])

    # NOTE: The a_to_b_map in k2.intersect_dense must be sorted
    # so the following reorders num_den_graphs.
    #
    # The following code computes a_to_b_map

    # [0, 1, 2, ... ]
    num_graphs_indexes = torch.arange(num_fsas, dtype=torch.int32)

    # [num_fsas, num_fsas, num_fsas, ... ]
    den_graphs_indexes = torch.tensor([num_fsas] * num_fsas,
                                      dtype=torch.int32)

    # [0, num_fsas, 1, num_fsas, 2, num_fsas, ... ]
    num_den_graphs_indexes = torch.stack(
        [num_graphs_indexes, den_graphs_indexes]).t().reshape(-1).to(device)

    num_den_reordered_graphs = k2.index(num_den_graphs,
                                        num_den_graphs_indexes)

    # [[0, 1, 2, ...]]
    a_to_b_map = torch.arange(num_fsas, dtype=torch.int32).reshape(1, -1)

    # [[0, 1, 2, ...]] -> [0, 0, 1, 1, 2, 2, ... ]
    a_to_b_map = a_to_b_map.repeat(2, 1).t().reshape(-1).to(device)

    num_den_lats = k2.intersect_dense(num_den_reordered_graphs,
                                      dense_fsa_vec,
                                      output_beam=10.0,
                                      a_to_b_map=a_to_b_map)

    num_den_tot_scores = num_den_lats.get_tot_scores(log_semiring=True,
                                                     use_double_scores=True)

    num_tot_scores = num_den_tot_scores[::2]
    den_tot_scores = num_den_tot_scores[1::2]

    tot_scores = num_tot_scores - den_scale * den_tot_scores
    tot_score, tot_frames, all_frames = get_tot_objf_and_num_frames(
        tot_scores, supervision_segments[:, 2])
    return tot_score, tot_frames, all_frames
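The index gymnastics above are easier to see with a concrete size; this sketch reproduces just the tensor arithmetic for num_fsas = 3:

import torch

num_fsas = 3
num_idx = torch.arange(num_fsas, dtype=torch.int32)  # [0, 1, 2]
den_idx = torch.tensor([num_fsas] * num_fsas,
                       dtype=torch.int32)            # [3, 3, 3]

# Interleave: each num graph i is followed by the (single, shared) den graph.
print(torch.stack([num_idx, den_idx]).t().reshape(-1))
# tensor([0, 3, 1, 3, 2, 3], dtype=torch.int32)

# Both copies for sequence i map to dense FSA i.
a_to_b_map = torch.arange(num_fsas, dtype=torch.int32).reshape(1, -1)
print(a_to_b_map.repeat(2, 1).t().reshape(-1))
# tensor([0, 0, 1, 1, 2, 2], dtype=torch.int32)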
def forward(ctx,
            a_fsas: Fsa,
            b_fsas: DenseFsaVec,
            out_fsa: List[Fsa],
            output_beam: float,
            unused_scores_a: torch.Tensor,
            unused_scores_b: torch.Tensor,
            a_to_b_map: Optional[torch.Tensor] = None,
            seqframe_idx_name: Optional[str] = None,
            frame_idx_name: Optional[str] = None) -> torch.Tensor:
    '''Intersect array of FSAs on CPU/GPU.

    Args:
      a_fsas:
        Input FsaVec, i.e., `decoding graphs`, one per sequence. It might
        just be a linear sequence of phones, or might be something more
        complicated. Must have number of FSAs equal to b_fsas.dim0(), if
        a_to_b_map is not specified.
      b_fsas:
        Input FSAs that correspond to neural network output.
      out_fsa:
        A list containing ONLY one entry which will be set to the
        generated FSA on return. We pass it as a list since the return
        value can only be types of torch.Tensor in the `forward` function.
      output_beam:
        Pruning beam for the output of intersection (vs. best path);
        equivalent to kaldi's lattice-beam. E.g. 8.
      unused_scores_a:
        It equals `a_fsas.scores` and its sole purpose is for back
        propagation.
      unused_scores_b:
        It equals `b_fsas.scores` and its sole purpose is for back
        propagation.
      a_to_b_map:
        Maps from FSA-index in a to FSA-index in b to use for it.
        If None, then we expect the number of FSAs in a_fsas to equal
        b_fsas.dim0(). If set, then it should be a Tensor with ndim=1
        and dtype=torch.int32, with a_to_b_map.shape[0] equal to the
        number of FSAs in a_fsas (i.e. a_fsas.shape[0] if
        len(a_fsas.shape) == 3, else 1); and elements
        0 <= i < b_fsas.dim0().
      seqframe_idx_name:
        If set (e.g. to 'seqframe'), an attribute in the output will be
        created that encodes the sequence-index and the frame-index within
        that sequence; this is equivalent to a row-index into
        b_fsas.values, or, equivalently, an element in b_fsas.shape.
      frame_idx_name:
        If set (e.g. to 'frame'), an attribute in the output will be
        created that contains the frame-index within the corresponding
        sequence.
    Returns:
      Return `out_fsa[0].scores`.
    '''
    assert len(out_fsa) == 1

    ragged_arc, arc_map_a, arc_map_b = _k2.intersect_dense(
        a_fsas=a_fsas.arcs,
        b_fsas=b_fsas.dense_fsa_vec,
        a_to_b_map=a_to_b_map,
        output_beam=output_beam)

    out_fsa[0] = Fsa(ragged_arc)

    for name, a_value in a_fsas.named_tensor_attr(include_scores=False):
        value = k2.index(a_value, arc_map_a)
        setattr(out_fsa[0], name, value)

    for name, a_value in a_fsas.named_non_tensor_attr():
        setattr(out_fsa[0], name, a_value)

    ctx.arc_map_a = arc_map_a
    ctx.arc_map_b = arc_map_b

    ctx.save_for_backward(unused_scores_a, unused_scores_b)

    seqframe_idx = None
    if frame_idx_name is not None:
        num_cols = b_fsas.dense_fsa_vec.scores_dim1()
        seqframe_idx = arc_map_b // num_cols
        shape = b_fsas.dense_fsa_vec.shape()
        fsa_idx0 = _k2.index_select(shape.row_ids(1), seqframe_idx)
        frame_idx = seqframe_idx - _k2.index_select(
            shape.row_splits(1), fsa_idx0)
        assert not hasattr(out_fsa[0], frame_idx_name)
        setattr(out_fsa[0], frame_idx_name, frame_idx)

    if seqframe_idx_name is not None:
        if seqframe_idx is None:
            num_cols = b_fsas.dense_fsa_vec.scores_dim1()
            seqframe_idx = arc_map_b // num_cols
        assert not hasattr(out_fsa[0], seqframe_idx_name)
        setattr(out_fsa[0], seqframe_idx_name, seqframe_idx)

    return out_fsa[0].scores
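The seqframe/frame arithmetic at the end is plain integer indexing; here it is re-done in pure PyTorch with made-up numbers (2 sequences of 3 and 2 frames, so row_splits = [0, 3, 5]):

import torch

num_cols = 10                              # stands in for scores_dim1()
arc_map_b = torch.tensor([0, 13, 27, 41])  # hypothetical arc -> score entry
seqframe_idx = arc_map_b // num_cols       # row index into b_fsas.scores

row_splits = torch.tensor([0, 3, 5])       # frame boundaries per sequence
row_ids = torch.tensor([0, 0, 0, 1, 1])    # sequence id of each frame

fsa_idx0 = row_ids[seqframe_idx]                 # which sequence each arc is in
frame_idx = seqframe_idx - row_splits[fsa_idx0]  # frame within that sequence
print(seqframe_idx)  # tensor([0, 1, 2, 4])
print(frame_idx)     # tensor([0, 1, 2, 1])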
def forward(ctx,
            a_fsas: Fsa,
            b_fsas: DenseFsaVec,
            out_fsa: List[Fsa],
            search_beam: float,
            output_beam: float,
            min_active_states: int,
            max_active_states: int,
            unused_scores_a: torch.Tensor,
            unused_scores_b: torch.Tensor,
            seqframe_idx_name: Optional[str] = None,
            frame_idx_name: Optional[str] = None) -> torch.Tensor:
    '''Intersect array of FSAs on CPU/GPU.

    Args:
      a_fsas:
        Input FsaVec, i.e., `decoding graphs`, one per sequence. It might
        just be a linear sequence of phones, or might be something more
        complicated. Must have either `a_fsas.shape[0] == b_fsas.dim0()`,
        or `a_fsas.shape[0] == 1` in which case the graph is shared.
      b_fsas:
        Input FSAs that correspond to neural network output.
      out_fsa:
        A list containing ONLY one entry which will be set to the
        generated FSA on return. We pass it as a list since the return
        value can only be types of torch.Tensor in the `forward` function.
      search_beam:
        Decoding beam, e.g. 20. Smaller is faster, larger is more exact
        (less pruning). This is the default value; it may be modified by
        `min_active_states` and `max_active_states`.
      output_beam:
        Pruning beam for the output of intersection (vs. best path);
        equivalent to kaldi's lattice-beam. E.g. 8.
      max_active_states:
        Maximum number of FSA states that are allowed to be active on any
        given frame for any given intersection/composition task. This is
        advisory, in that it will try not to exceed that but may not
        always succeed. You can use a very large number if no constraint
        is needed.
      min_active_states:
        Minimum number of FSA states that are allowed to be active on any
        given frame for any given intersection/composition task. This is
        advisory, in that it will try not to have fewer than this number
        active. Set it to zero if there is no constraint.
      unused_scores_a:
        It equals `a_fsas.scores` and its sole purpose is for back
        propagation.
      unused_scores_b:
        It equals `b_fsas.scores` and its sole purpose is for back
        propagation.
      seqframe_idx_name:
        If set (e.g. to 'seqframe'), an attribute in the output will be
        created that encodes the sequence-index and the frame-index within
        that sequence; this is equivalent to a row-index into
        b_fsas.values, or, equivalently, an element in b_fsas.shape.
      frame_idx_name:
        If set (e.g. to 'frame'), an attribute in the output will be
        created that contains the frame-index within the corresponding
        sequence.
    Returns:
      Return `out_fsa[0].scores`.
    '''
    assert len(out_fsa) == 1

    ragged_arc, arc_map_a, arc_map_b = _k2.intersect_dense_pruned(
        a_fsas=a_fsas.arcs,
        b_fsas=b_fsas.dense_fsa_vec,
        search_beam=search_beam,
        output_beam=output_beam,
        min_active_states=min_active_states,
        max_active_states=max_active_states)

    out_fsa[0] = Fsa(ragged_arc)

    for name, a_value in a_fsas.named_tensor_attr(include_scores=False):
        value = k2.index(a_value, arc_map_a)
        setattr(out_fsa[0], name, value)

    for name, a_value in a_fsas.named_non_tensor_attr():
        setattr(out_fsa[0], name, a_value)

    ctx.arc_map_a = arc_map_a
    ctx.arc_map_b = arc_map_b

    ctx.save_for_backward(unused_scores_a, unused_scores_b)

    seqframe_idx = None
    if frame_idx_name is not None:
        num_cols = b_fsas.dense_fsa_vec.scores_dim1()
        seqframe_idx = arc_map_b // num_cols
        shape = b_fsas.dense_fsa_vec.shape()
        fsa_idx0 = _k2.index_select(shape.row_ids(1), seqframe_idx)
        frame_idx = seqframe_idx - _k2.index_select(
            shape.row_splits(1), fsa_idx0)
        assert not hasattr(out_fsa[0], frame_idx_name)
        setattr(out_fsa[0], frame_idx_name, frame_idx)

    if seqframe_idx_name is not None:
        if seqframe_idx is None:
            num_cols = b_fsas.dense_fsa_vec.scores_dim1()
            seqframe_idx = arc_map_b // num_cols
        assert not hasattr(out_fsa[0], seqframe_idx_name)
        setattr(out_fsa[0], seqframe_idx_name, seqframe_idx)

    return out_fsa[0].scores
def rescore_with_n_best_list(lats: k2.Fsa, G: k2.Fsa,
                             num_paths: int) -> k2.Fsa:
    '''Decode using n-best list with LM rescoring.

    `lats` is a decoding lattice, which has 3 axes. This function first
    extracts `num_paths` paths from `lats` for each sequence using
    `k2.random_paths`. The `am_scores` of these paths are computed.
    For each path, its `lm_scores` is computed using `G` (which is an LM).
    The final `tot_scores` is the sum of `am_scores` and `lm_scores`.
    The path with the greatest `tot_scores` within a sequence is used
    as the decoding output.

    Args:
      lats:
        An FsaVec. It can be the output of `k2.intersect_dense_pruned`.
      G:
        An FsaVec representing the language model (LM). Note that it is
        an FsaVec, but it contains only one Fsa.
      num_paths:
        It is the size `n` in `n-best` list.
    Returns:
      An FsaVec representing the best decoding path for each sequence
      in the lattice.
    '''
    device = lats.device

    assert len(lats.shape) == 3
    assert hasattr(lats, 'aux_labels')
    assert hasattr(lats, 'lm_scores')

    assert G.shape == (1, None, None)
    assert G.device == device
    assert hasattr(G, 'aux_labels') is False

    # First, extract `num_paths` paths for each sequence.
    # paths is a k2.RaggedInt with axes [seq][path][arc_pos]
    paths = k2.random_paths(lats, num_paths=num_paths, use_double_scores=True)

    # word_seqs is a k2.RaggedInt sharing the same shape as `paths`
    # but it contains word IDs. Note that it also contains 0s and -1s.
    # The last entry in each sublist is -1.
    word_seqs = k2.index(lats.aux_labels, paths)

    # Remove epsilons and -1 from word_seqs
    word_seqs = k2.ragged.remove_values_leq(word_seqs, 0)

    # Remove repeated sequences to avoid redundant computation later.
    #
    # unique_word_seqs is still a k2.RaggedInt with 3 axes [seq][path][word]
    # except that there are no repeated paths with the same word_seq
    # within a seq.
    #
    # num_repeats is also a k2.RaggedInt with 2 axes containing the
    # multiplicities of each path.
    # num_repeats.num_elements() == unique_word_seqs.num_elements()
    #
    # Since k2.ragged.unique_sequences will reorder paths within a seq,
    # `new2old` is a 1-D torch.Tensor mapping from the output path index
    # to the input path index.
    # new2old.numel() == unique_word_seqs.num_elements()
    unique_word_seqs, num_repeats, new2old = k2.ragged.unique_sequences(
        word_seqs, need_num_repeats=True, need_new2old_indexes=True)

    seq_to_path_shape = k2.ragged.get_layer(unique_word_seqs.shape(), 0)

    # path_to_seq_map is a 1-D torch.Tensor.
    # path_to_seq_map[i] is the seq to which the i-th path
    # belongs.
    path_to_seq_map = seq_to_path_shape.row_ids(1)

    # Remove the seq axis.
    # Now unique_word_seqs has only two axes [path][word]
    unique_word_seqs = k2.ragged.remove_axis(unique_word_seqs, 0)

    # word_fsas is an FsaVec with axes [path][state][arc]
    word_fsas = k2.linear_fsa(unique_word_seqs)

    word_fsas_with_epsilon_loops = k2.add_epsilon_self_loops(word_fsas)

    am_scores = compute_am_scores(lats, word_fsas_with_epsilon_loops,
                                  path_to_seq_map)

    # Now compute lm_scores
    b_to_a_map = torch.zeros_like(path_to_seq_map)
    lm_path_lats = _intersect_device(G,
                                     word_fsas_with_epsilon_loops,
                                     b_to_a_map=b_to_a_map,
                                     sorted_match_a=True)
    lm_path_lats = k2.top_sort(k2.connect(lm_path_lats.to('cpu'))).to(device)
    lm_scores = lm_path_lats.get_tot_scores(True, True)

    tot_scores = am_scores + lm_scores

    # Remember that we used `k2.ragged.unique_sequences` to remove repeated
    # paths to avoid redundant computation in `k2.intersect_device`.
    # Now we use `num_repeats` to correct the scores for each path.
    #
    # NOTE(fangjun): It is commented out as it leads to a worse WER
    # tot_scores = tot_scores * num_repeats.values()

    # TODO(fangjun): We may need to add `k2.RaggedDouble`
    ragged_tot_scores = k2.RaggedFloat(seq_to_path_shape,
                                       tot_scores.to(torch.float32))
    argmax_indexes = k2.ragged.argmax_per_sublist(ragged_tot_scores)

    # Use k2.index here since argmax_indexes' dtype is torch.int32
    best_path_indexes = k2.index(new2old, argmax_indexes)

    paths = k2.ragged.remove_axis(paths, 0)

    # best_paths is a k2.RaggedInt with 2 axes [path][arc_pos]
    best_paths = k2.index(paths, best_path_indexes)

    # labels is a k2.RaggedInt with 2 axes [path][phone_id]
    # Note that it contains -1s.
    labels = k2.index(lats.labels.contiguous(), best_paths)

    labels = k2.ragged.remove_values_eq(labels, -1)

    # lats.aux_labels is a k2.RaggedInt tensor with 2 axes, so
    # aux_labels is also a k2.RaggedInt with 2 axes
    aux_labels = k2.index(lats.aux_labels, best_paths.values())

    best_path_fsas = k2.linear_fsa(labels)
    best_path_fsas.aux_labels = aux_labels

    return best_path_fsas
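The per-sequence argmax over `ragged_tot_scores` can be pictured in plain PyTorch. A sketch with made-up scores for 2 sequences holding 3 and 2 unique paths (so row_splits = [0, 3, 5]); like k2.ragged.argmax_per_sublist, the indexes here are absolute positions into the flattened values, which is why they can be mapped through `new2old` above:

import torch

tot_scores = torch.tensor([1.0, 5.0, 3.0, 2.0, 4.0])
row_splits = [0, 3, 5]  # 2 sequences with 3 and 2 unique paths

argmax_indexes = [
    s + int(torch.argmax(tot_scores[s:e]))
    for s, e in zip(row_splits[:-1], row_splits[1:])
]
print(argmax_indexes)  # [1, 4]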
def nbest_decoding(lats: k2.Fsa, num_paths: int):
    '''
    (Ideas of this function are from Dan)

    It implements something like CTC prefix beam search using n-best lists.

    The basic idea is to first extract n-best paths from the given lattice,
    build word sequences from these paths, and compute the total scores
    of these sequences in the log-semiring. The one with the max score
    is used as the decoding output.
    '''
    # First, extract `num_paths` paths for each sequence.
    # paths is a k2.RaggedInt with axes [seq][path][arc_pos]
    paths = k2.random_paths(lats, num_paths=num_paths, use_double_scores=True)

    # word_seqs is a k2.RaggedInt sharing the same shape as `paths`
    # but it contains word IDs. Note that it also contains 0s and -1s.
    # The last entry in each sublist is -1.
    word_seqs = k2.index(lats.aux_labels, paths)
    # Note: the above operation supports also the case when
    # lats.aux_labels is a ragged tensor. In that case,
    # `remove_axis=True` is used inside the pybind11 binding code,
    # so the resulting `word_seqs` still has 3 axes, like `paths`.
    # The 3 axes are [seq][path][word]

    # Remove epsilons and -1 from word_seqs
    word_seqs = k2.ragged.remove_values_leq(word_seqs, 0)

    # Remove repeated sequences to avoid redundant computation later.
    #
    # Since k2.ragged.unique_sequences will reorder paths within a seq,
    # `new2old` is a 1-D torch.Tensor mapping from the output path index
    # to the input path index.
    # new2old.numel() == unique_word_seqs.num_elements()
    unique_word_seqs, _, new2old = k2.ragged.unique_sequences(
        word_seqs, need_num_repeats=False, need_new2old_indexes=True)
    # Note: unique_word_seqs still has the same axes as word_seqs

    seq_to_path_shape = k2.ragged.get_layer(unique_word_seqs.shape(), 0)

    # path_to_seq_map is a 1-D torch.Tensor.
    # path_to_seq_map[i] is the seq to which the i-th path
    # belongs.
    path_to_seq_map = seq_to_path_shape.row_ids(1)

    # Remove the seq axis.
    # Now unique_word_seqs has only two axes [path][word]
    unique_word_seqs = k2.ragged.remove_axis(unique_word_seqs, 0)

    # word_fsas is an FsaVec with axes [path][state][arc]
    word_fsas = k2.linear_fsa(unique_word_seqs)

    word_fsas_with_epsilon_loops = k2.add_epsilon_self_loops(word_fsas)

    # lats has phone IDs as labels and word IDs as aux_labels.
    # inv_lats has word IDs as labels and phone IDs as aux_labels
    inv_lats = k2.invert(lats)
    inv_lats = k2.arc_sort(inv_lats)  # no-op if inv_lats is already arc-sorted

    path_lats = k2.intersect_device(inv_lats,
                                    word_fsas_with_epsilon_loops,
                                    b_to_a_map=path_to_seq_map,
                                    sorted_match_a=True)
    # path_lats has word IDs as labels and phone IDs as aux_labels

    path_lats = k2.top_sort(k2.connect(path_lats.to('cpu')).to(lats.device))

    tot_scores = path_lats.get_tot_scores(True, True)

    # RaggedFloat currently supports float32 only.
    # We may bind Ragged<double> as RaggedDouble if needed.
    ragged_tot_scores = k2.RaggedFloat(seq_to_path_shape,
                                       tot_scores.to(torch.float32))

    argmax_indexes = k2.ragged.argmax_per_sublist(ragged_tot_scores)

    # Since we invoked `k2.ragged.unique_sequences`, which reorders
    # the index from `paths`, we use `new2old` here to convert
    # argmax_indexes to the indexes into `paths`.
    #
    # Use k2.index here since argmax_indexes' dtype is torch.int32
    best_path_indexes = k2.index(new2old, argmax_indexes)

    paths_2axes = k2.ragged.remove_axis(paths, 0)

    # best_paths is a k2.RaggedInt with 2 axes [path][arc_pos]
    best_paths = k2.index(paths_2axes, best_path_indexes)

    # labels is a k2.RaggedInt with 2 axes [path][phone_id]
    # Note that it contains -1s.
    labels = k2.index(lats.labels.contiguous(), best_paths)

    labels = k2.ragged.remove_values_eq(labels, -1)

    # lats.aux_labels is a k2.RaggedInt tensor with 2 axes, so
    # aux_labels is also a k2.RaggedInt with 2 axes
    aux_labels = k2.index(lats.aux_labels, best_paths.values())

    best_path_fsas = k2.linear_fsa(labels)
    best_path_fsas.aux_labels = aux_labels

    return best_path_fsas
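The relabeling trick above (phone labels vs. word aux_labels) drives the intersection, so here is a tiny sketch of what `k2.invert` does to a single FSA, with made-up symbol IDs:

import torch
import k2

fsa = k2.Fsa.from_str('''
0 1 2 0.1
1 2 -1 0.2
2
''')
# One aux_label per arc; the final arc's aux_label is -1, matching its label.
fsa.aux_labels = torch.tensor([10, -1], dtype=torch.int32)

inv = k2.arc_sort(k2.invert(fsa))
# After inversion the first arc has label 10 and aux_label 2:
# labels and aux_labels have been swapped.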
def forward(
        self, nnet_output: torch.Tensor, texts: List,
        supervision_segments: torch.Tensor
) -> Tuple[torch.Tensor, int, int]:
    num_graphs, den_graphs = self.graph_compiler.compile(
        texts, self.P, replicate_den=False)

    dense_fsa_vec = k2.DenseFsaVec(nnet_output, supervision_segments)

    device = num_graphs.device

    num_fsas = num_graphs.shape[0]
    assert dense_fsa_vec.dim0() == num_fsas

    assert den_graphs.shape[0] == 1

    # the aux_labels of num_graphs is k2.RaggedInt
    # but it is torch.Tensor for den_graphs.
    #
    # The following converts den_graphs.aux_labels
    # from torch.Tensor to k2.RaggedInt so that
    # we can use k2.append() later
    den_graphs.convert_attr_to_ragged_(name='aux_labels')

    num_den_graphs = k2.cat([num_graphs, den_graphs])

    # NOTE: The a_to_b_map in k2.intersect_dense must be sorted
    # so the following reorders num_den_graphs.

    # [0, 1, 2, ... ]
    num_graphs_indexes = torch.arange(num_fsas, dtype=torch.int32)

    # [num_fsas, num_fsas, num_fsas, ... ]
    den_graphs_indexes = torch.tensor([num_fsas] * num_fsas,
                                      dtype=torch.int32)

    # [0, num_fsas, 1, num_fsas, 2, num_fsas, ... ]
    num_den_graphs_indexes = torch.stack(
        [num_graphs_indexes, den_graphs_indexes]).t().reshape(-1).to(device)

    num_den_reordered_graphs = k2.index(num_den_graphs,
                                        num_den_graphs_indexes)

    # [[0, 1, 2, ...]]
    a_to_b_map = torch.arange(num_fsas, dtype=torch.int32).reshape(1, -1)

    # [[0, 1, 2, ...]] -> [0, 0, 1, 1, 2, 2, ... ]
    a_to_b_map = a_to_b_map.repeat(2, 1).t().reshape(-1).to(device)

    num_den_lats = k2.intersect_dense(num_den_reordered_graphs,
                                      dense_fsa_vec,
                                      output_beam=10.0,
                                      a_to_b_map=a_to_b_map)

    num_den_tot_scores = num_den_lats.get_tot_scores(
        log_semiring=True, use_double_scores=True)

    num_tot_scores = num_den_tot_scores[::2]
    den_tot_scores = num_den_tot_scores[1::2]

    tot_scores = num_tot_scores - self.den_scale * den_tot_scores
    tot_score, tot_frames, all_frames = get_tot_objf_and_num_frames(
        tot_scores, supervision_segments[:, 2])
    return tot_score, tot_frames, all_frames
def forward(ctx, a_fsas: Fsa, b_fsas: DenseFsaVec, out_fsa: List[Fsa],
            output_beam: float, unused_scores_a: torch.Tensor,
            unused_scores_b: torch.Tensor) -> torch.Tensor:
    '''Intersect array of FSAs on CPU/GPU.

    Args:
      a_fsas:
        Input FsaVec, i.e., `decoding graphs`, one per sequence. It might
        just be a linear sequence of phones, or might be something more
        complicated. Must have `a_fsas.shape[0] == b_fsas.dim0()`.
      b_fsas:
        Input FSAs that correspond to neural network output.
      out_fsa:
        A list containing ONLY one entry which will be set to the
        generated FSA on return. We pass it as a list since the return
        value can only be types of torch.Tensor in the `forward` function.
      output_beam:
        Pruning beam for the output of intersection (vs. best path);
        equivalent to kaldi's lattice-beam. E.g. 8.
      unused_scores_a:
        It equals `a_fsas.scores` and its sole purpose is for back
        propagation.
      unused_scores_b:
        It equals `b_fsas.scores` and its sole purpose is for back
        propagation.
    Returns:
      Return `out_fsa[0].scores`.
    '''
    assert len(out_fsa) == 1

    ragged_arc, arc_map_a, arc_map_b = _k2.intersect_dense(
        a_fsas=a_fsas.arcs,
        b_fsas=b_fsas.dense_fsa_vec,
        output_beam=output_beam)

    out_fsa[0] = Fsa(ragged_arc)

    for name, a_value in a_fsas.named_tensor_attr(include_scores=False):
        value = k2.index(a_value, arc_map_a)
        setattr(out_fsa[0], name, value)

    for name, a_value in a_fsas.named_non_tensor_attr():
        setattr(out_fsa[0], name, a_value)

    ctx.arc_map_a = arc_map_a
    ctx.arc_map_b = arc_map_b

    ctx.save_for_backward(unused_scores_a, unused_scores_b)
    return out_fsa[0].scores