def expand_ragged_attributes(
        fsas: Fsa,
        ret_arc_map: bool = False,
        ragged_attribute_names: Optional[List[str]] = None
) -> Union[Fsa, Tuple[Fsa, torch.Tensor]]:  # noqa
    '''Turn ragged labels attached to this FSA into linear (Tensor) labels,
    expanding arcs into sequences of arcs as necessary to achieve this.
    Supports autograd.  If `fsas` has no ragged attributes, returns `fsas`
    itself.

    Caution: this function ensures that for final-arcs in the returned FSA,
    the corresponding labels for all ragged attributes are -1; it will add
    an extra arc at the end if necessary to ensure this, in case the
    original ragged attributes did not have -1 as their final element on
    final-arcs.  (Note: our intention is that -1's on final arcs, like
    filler symbols, are removed when making attributes ragged; this is what
    fsa_from_unary_function_ragged() does if remove_filler == True, the
    default.)

    Args:
      fsas:
        The source Fsa.
      ret_arc_map:
        If True, will return a pair (new_fsas, arc_map) with `arc_map` a
        tensor of int32 that maps from arcs in the result to arcs in
        `fsas`, with -1's for newly created arcs.  If False, just returns
        new_fsas.
      ragged_attribute_names:
        If specified, only this list of ragged attributes will be expanded
        to linear tensor attributes; the rest will stay ragged.
    '''
    if ragged_attribute_names is None:
        ragged_attribute_tensors = []
        ragged_attribute_names = []
        for name, value in fsas.named_tensor_attr(include_scores=False):
            if isinstance(value, k2.RaggedInt):
                ragged_attribute_tensors.append(value)
                ragged_attribute_names.append(name)
    else:
        ragged_attribute_tensors = [
            getattr(fsas, name) for name in ragged_attribute_names
        ]
        for t in ragged_attribute_tensors:
            assert isinstance(t, k2.RaggedInt)

    if len(ragged_attribute_tensors) == 0:
        if ret_arc_map:
            arc_map = torch.arange(fsas.num_arcs,
                                   dtype=torch.int32,
                                   device=fsas.device)
            return (fsas, arc_map)
        else:
            return fsas

    (dest_arcs, dest_labels,
     arc_map) = _k2.expand_arcs(fsas.arcs, ragged_attribute_tensors)

    # The rest of this function is a modified version of
    # `fsa_from_unary_function_tensor()`.
    dest = Fsa(dest_arcs)

    # Handle the non-ragged attributes, and the ragged attributes that
    # we're not linearizing.
    for name, value in fsas.named_tensor_attr(include_scores=False):
        if isinstance(value, torch.Tensor):
            filler = float(fsas.get_filler(name))
            setattr(dest, name,
                    index_select(value, arc_map, default_value=filler))
        elif name not in ragged_attribute_names:
            setattr(dest, name, index(value, arc_map))

    # Handle the attributes that were ragged but are now linear.
    for name, value in zip(ragged_attribute_names, dest_labels):
        setattr(dest, name, value)

    # Copy non-tensor attributes.
    for name, value in fsas.named_non_tensor_attr():
        setattr(dest, name, value)

    # Make sure autograd works on the scores.
    k2.autograd_utils.phantom_index_select_scores(dest, fsas.scores, arc_map)

    # Make sure -1's are only on final-arcs, and never on non-final arcs.
    if hasattr(dest, 'aux_labels'):
        _k2.fix_final_labels(dest.arcs, dest.aux_labels)

    if ret_arc_map:
        return dest, arc_map
    else:
        return dest
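
# A minimal usage sketch (not part of the library itself): it shows how
# expand_ragged_attributes() could be called on an Fsa carrying a ragged
# attribute.  The attribute name `ragged_attr` and the small FSA below are
# made up for illustration; only functions already used in this file and
# its test (k2.Fsa.from_str, k2.RaggedInt, k2.expand_ragged_attributes)
# are assumed.
def _expand_ragged_attributes_example():
    import k2

    s = '''
        0 1 2 10
        1 2 -1 30
        2
    '''
    fsa = k2.Fsa.from_str(s)
    # Attach a ragged attribute: one sub-list of labels per arc.
    fsa.ragged_attr = k2.RaggedInt('[[1 2] [3]]')

    # Each sub-list is spread over a chain of arcs, so dest.ragged_attr is
    # an ordinary torch.Tensor with one value per arc; arc_map maps arcs in
    # dest back to arcs in fsa, with -1 for newly created arcs.
    dest, arc_map = k2.expand_ragged_attributes(fsa, ret_arc_map=True)
    return dest, arc_map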
def test_final(self):
    devices = [torch.device('cpu')]
    if torch.cuda.is_available() and k2.with_cuda:
        devices.append(torch.device('cuda', 0))
    for device in devices:
        for need_map in [True, False]:
            s = '''
                0 1 2 10
                0 1 1 20
                1 2 -1 30
                2
            '''
            src = k2.Fsa.from_str(s).to(device).requires_grad_(True)
            src.float_attr = torch.tensor([0.1, 0.2, 0.3],
                                          dtype=torch.float32,
                                          requires_grad=True,
                                          device=device)
            src.int_attr = torch.tensor([1, 2, 3],
                                        dtype=torch.int32,
                                        device=device)
            src.ragged_attr = k2.RaggedInt('[[1 2 3] [5 6] [1]]').to(
                device)
            src.attr1 = 'src'
            src.attr2 = 'fsa'

            if need_map:
                dest, arc_map = k2.expand_ragged_attributes(
                    src, ret_arc_map=True)
            else:
                dest = k2.expand_ragged_attributes(src)

            assert torch.allclose(
                dest.float_attr,
                torch.tensor([0.1, 0.2, 0.0, 0.0, 0.0, 0.3, 0.0],
                             dtype=torch.float32,
                             device=device))

            assert torch.all(
                torch.eq(
                    dest.scores,
                    torch.tensor([10, 20, 0, 0, 0, 30, 0],
                                 dtype=torch.float32,
                                 device=device)))

            assert torch.all(
                torch.eq(
                    dest.int_attr,
                    torch.tensor([1, 2, 0, 0, 0, 3, 0],
                                 dtype=torch.int32,
                                 device=device)))

            _k2.fix_final_labels(dest.arcs, dest.int_attr)

            assert torch.all(
                torch.eq(
                    dest.int_attr,
                    torch.tensor([1, 2, 0, 0, 0, 3, -1],
                                 dtype=torch.int32,
                                 device=device)))

            assert torch.all(
                torch.eq(
                    dest.ragged_attr,
                    torch.tensor([1, 5, 2, 3, 6, 1, -1],
                                 dtype=torch.float32,
                                 device=device)))

            # non-tensor attributes...
            assert dest.attr1 == src.attr1
            assert dest.attr2 == src.attr2

            # now for autograd
            scale = torch.tensor([10, 20, 10, 10, 10, 30, 10],
                                 device=device)
            (dest.float_attr * scale).sum().backward()
            (dest.scores * scale).sum().backward()

            expected_grad = torch.tensor([10, 20, 30],
                                         dtype=torch.float32,
                                         device=device)

            assert torch.all(torch.eq(src.float_attr.grad, expected_grad))
            assert torch.all(torch.eq(src.scores.grad, expected_grad))