def expand_ragged_attributes(
        fsas: Fsa,
        ret_arc_map: bool = False
) -> Union[Fsa, Tuple[Fsa, torch.Tensor]]:  # noqa
    '''Replace the ragged attributes of an FSA with linear (Tensor) ones.

    Arcs are expanded into sequences of arcs wherever that is necessary to
    hold the ragged labels.  Supports autograd.  If `fsas` has no ragged
    attributes, `fsas` itself is returned.

    Args:
      fsas:
        The source Fsa.
      ret_arc_map:
        If true, return a pair (new_fsas, arc_map) where `arc_map` is an
        int32 tensor mapping arcs of the result to arcs of `fsas`, with
        -1's for newly created arcs.  If false, return just new_fsas.

    Returns:
      Either the resulting Fsa, or the pair (Fsa, arc_map), depending on
      `ret_arc_map`.
    '''
    # Gather (name, value) for every ragged attribute, in iteration order.
    ragged_items = [
        (attr_name, attr_value)
        for attr_name, attr_value in fsas.named_tensor_attr(
            include_scores=False)
        if isinstance(attr_value, k2.RaggedInt)
    ]

    # Nothing to expand: hand back the input, with an identity arc map
    # if one was requested.
    if not ragged_items:
        if not ret_arc_map:
            return fsas
        identity_map = torch.arange(fsas.num_arcs,
                                    dtype=torch.int32,
                                    device=fsas.device)
        return (fsas, identity_map)

    ragged_names = [attr_name for attr_name, _ in ragged_items]
    ragged_values = [attr_value for _, attr_value in ragged_items]

    expanded_arcs, linear_labels, arc_map = _k2.expand_arcs(
        fsas.arcs, ragged_values)

    # The rest of this function is a modified version of
    # `fsa_from_unary_function_tensor()`.
    dest = Fsa(expanded_arcs)

    # Attributes that were never ragged are re-indexed through arc_map.
    for attr_name, attr_value in fsas.named_tensor_attr(
            include_scores=False):
        if isinstance(attr_value, k2.RaggedInt):
            continue
        setattr(dest, attr_name, k2.index(attr_value, arc_map))

    # Attributes that were ragged take the linear labels from expand_arcs.
    for attr_name, linear_value in zip(ragged_names, linear_labels):
        setattr(dest, attr_name, linear_value)

    # Non-tensor attributes are copied across unchanged.
    for attr_name, attr_value in fsas.named_non_tensor_attr():
        setattr(dest, attr_name, attr_value)

    # make sure autograd works on the scores
    k2.autograd_utils.phantom_index_select_scores(dest, fsas.scores, arc_map)

    return (dest, arc_map) if ret_arc_map else dest
def expand_ragged_attributes(
        fsas: Fsa,
        ret_arc_map: bool = False,
        ragged_attribute_names: Optional[List[str]] = None
) -> Union[Fsa, Tuple[Fsa, torch.Tensor]]:  # noqa
    '''Replace ragged attributes of an FSA with linear (Tensor) ones.

    Arcs are expanded into sequences of arcs wherever that is necessary to
    hold the ragged labels.  Supports autograd.  If there are no ragged
    attributes to expand, `fsas` itself is returned.

    Caution:
      This function will ensure that for final-arcs in the returned fsa,
      the corresponding labels for all ragged attributes are -1; it will
      add an extra arc at the end if necessary to ensure this, if the
      original ragged attributes did not have -1 as their final element on
      final-arcs.  (Note: our intention is that -1's on final arcs, like
      filler symbols, are removed when making attributes ragged; this is
      what fsa_from_unary_function_ragged() does if remove_filler==True,
      the default.)

    Args:
      fsas:
        The source Fsa.
      ret_arc_map:
        If true, return a pair (new_fsas, arc_map) where `arc_map` is an
        int32 tensor mapping arcs of the result to arcs of `fsas`, with
        -1's for newly created arcs.  If false, return just new_fsas.
      ragged_attribute_names:
        If specified, just this list of ragged attributes will be expanded
        to linear tensor attributes, and the rest will stay ragged.  Each
        named attribute must be a k2.RaggedInt.

    Returns:
      Either the resulting Fsa, or the pair (Fsa, arc_map), depending on
      `ret_arc_map`.
    '''
    if ragged_attribute_names is None:
        # No explicit selection: expand every ragged attribute found.
        discovered = [
            (attr_name, attr_value)
            for attr_name, attr_value in fsas.named_tensor_attr(
                include_scores=False)
            if isinstance(attr_value, k2.RaggedInt)
        ]
        expand_names = [attr_name for attr_name, _ in discovered]
        expand_values = [attr_value for _, attr_value in discovered]
    else:
        expand_names = ragged_attribute_names
        expand_values = [
            getattr(fsas, attr_name) for attr_name in expand_names
        ]
        for candidate in expand_values:
            assert isinstance(candidate, k2.RaggedInt)

    # Nothing to expand: hand back the input, with an identity arc map
    # if one was requested.
    if not expand_values:
        if not ret_arc_map:
            return fsas
        identity_map = torch.arange(fsas.num_arcs,
                                    dtype=torch.int32,
                                    device=fsas.device)
        return (fsas, identity_map)

    expanded_arcs, linear_labels, arc_map = _k2.expand_arcs(
        fsas.arcs, expand_values)

    # The rest of this function is a modified version of
    # `fsa_from_unary_function_tensor()`.
    dest = Fsa(expanded_arcs)

    # Handle the non-ragged attributes, and ragged attributes that
    # we're not linearizing.
    for attr_name, attr_value in fsas.named_tensor_attr(
            include_scores=False):
        if isinstance(attr_value, torch.Tensor):
            # The attribute's filler is passed as the default value for
            # the index_select call.
            fill_value = float(fsas.get_filler(attr_name))
            reindexed = index_select(attr_value,
                                     arc_map,
                                     default_value=fill_value)
            setattr(dest, attr_name, reindexed)
        elif attr_name not in expand_names:
            setattr(dest, attr_name, index(attr_value, arc_map))

    # Attributes that were ragged take the linear labels from expand_arcs.
    for attr_name, linear_value in zip(expand_names, linear_labels):
        setattr(dest, attr_name, linear_value)

    # Non-tensor attributes are copied across unchanged.
    for attr_name, attr_value in fsas.named_non_tensor_attr():
        setattr(dest, attr_name, attr_value)

    # make sure autograd works on the scores
    k2.autograd_utils.phantom_index_select_scores(dest, fsas.scores, arc_map)

    # Make sure -1's are only on final-arcs, and never on non-final arcs.
    # NOTE(review): if `aux_labels` was ragged and not listed in
    # `ragged_attribute_names`, it is still ragged here -- confirm that
    # fix_final_labels accepts that.
    if hasattr(dest, 'aux_labels'):
        _k2.fix_final_labels(dest.arcs, dest.aux_labels)

    return (dest, arc_map) if ret_arc_map else dest