def get_texts(best_paths: k2.Fsa) -> List[List[int]]:
    """Decode the label sequences carried by a batch of best-path FSAs.

    Args:
      best_paths:
        A k2.Fsa holding several FSAs (best_paths.arcs.num_axes() == 3),
        normally produced by k2.shortest_path; other inputs yield
        meaningless output. It must carry an 'aux_labels' attribute, either
        as a k2.RaggedInt or as a per-arc tensor.

    Return:
      One list of ints per FSA, holding the decoded labels with every
      value <= 0 (e.g. the final -1) filtered out.
    """
    labels = best_paths.aux_labels
    if not isinstance(labels, k2.RaggedInt):
        # One label per arc: drop the state axis so the shape becomes
        # [fsa][arc], wrap the labels with it, then discard 0's and -1's.
        per_fsa_shape = k2r.remove_axis(best_paths.arcs.shape(), 1)
        ragged = k2.RaggedInt(per_fsa_shape, labels)
        ragged = k2r.remove_values_leq(ragged, 0)
    else:
        # Filter first (there should be no 0's left, but -1's may remain).
        filtered = k2r.remove_values_leq(labels, 0)
        combined = k2r.compose_ragged_shapes(
            best_paths.arcs.shape(), filtered.shape()
        )
        # Axis 1 is the state axis; once it is gone, axis 1 is the arc
        # axis. Removing axis 1 twice leaves [fsa][label].
        for _ in range(2):
            combined = k2r.remove_axis(combined, 1)
        ragged = k2.RaggedInt(combined, filtered.values())

    assert ragged.num_axes() == 2
    return k2r.to_list(ragged)
def get_texts(
    best_paths: k2.Fsa, indices: Optional[torch.Tensor] = None
) -> List[List[int]]:
    """Extract the texts from the best-path FSAs, in the original order
    (before the permutation given by `indices`), if `indices` is given.

    Args:
      best_paths:
        A k2.Fsa with best_paths.arcs.num_axes() == 3, i.e. containing
        multiple FSAs. It is expected to be the result of k2.shortest_path;
        otherwise the returned values won't be meaningful. It must have the
        'aux_labels' attribute, either as a k2.RaggedInt or as a per-arc
        tensor.
      indices:
        Optionally, a torch.Tensor giving the permutation that was used on
        the supervisions of this minibatch to put them in decreasing order
        of num-frames. Its inverse is applied so the results line up with
        the original order. It does not have to be on the same device as
        `best_paths`. If None (the default), the results are returned in
        the order of the FSAs in `best_paths`.

    Return:
      A list of lists of int containing the decoded label sequences, with
      every value <= 0 removed.
    """
    if isinstance(best_paths.aux_labels, k2.RaggedInt):
        # Remove any 0's or -1's (there should be no 0's left, but there
        # may be -1's).
        aux_labels = k2r.remove_values_leq(best_paths.aux_labels, 0)
        aux_shape = k2r.compose_ragged_shapes(
            best_paths.arcs.shape(), aux_labels.shape()
        )
        # Remove the states and arcs axes (each is axis 1 in turn).
        aux_shape = k2r.remove_axis(aux_shape, 1)
        aux_shape = k2r.remove_axis(aux_shape, 1)
        aux_labels = k2.RaggedInt(aux_shape, aux_labels.values())
    else:
        # Remove the axis corresponding to states.
        aux_shape = k2r.remove_axis(best_paths.arcs.shape(), 1)
        aux_labels = k2.RaggedInt(aux_shape, best_paths.aux_labels)
        # Remove 0's and -1's.
        aux_labels = k2r.remove_values_leq(aux_labels, 0)

    assert aux_labels.num_axes() == 2

    if indices is not None:
        # Undo the sort-by-num-frames permutation. The index tensor is moved
        # to the FSA's device and cast to int32 before k2r.index is called.
        # Skipping this step when `indices` is None makes the default
        # argument usable; previously invert_permutation(None) was called.
        inverse = invert_permutation(indices).to(
            dtype=torch.int32, device=best_paths.device
        )
        aux_labels, _ = k2r.index(aux_labels, inverse)

    return k2r.to_list(aux_labels)