def create_fsa_vec(fsas):
    '''Create an FsaVec from a list of FSAs.

    We use the following rules to set the attributes of the output FsaVec:

    - For tensor attributes, we assume that all input FSAs have the same
      attribute name and the values are concatenated. The names are
      gathered from every input FSA; an attribute that exists on only
      some of them is not skipped, so the lookup on an FSA that lacks it
      fails.

    - For non-tensor attributes, if any two of the input FSAs have the same
      attribute name, then we assume that their attribute values are equal
      and the output FSA will inherit the attribute. The `properties`
      attribute is never copied.

    Args:
      fsas:
        A list of `Fsa`. Each element must be a single FSA.

    Returns:
      An instance of :class:`Fsa` that represents a FsaVec.
    '''
    ragged_arc_list = []
    for fsa in fsas:
        # Each input must be a single FSA (its shape has two entries).
        assert len(fsa.shape) == 2
        ragged_arc_list.append(fsa.arcs)

    ragged_arcs = _k2.create_fsa_vec(ragged_arc_list)
    fsa_vec = Fsa(ragged_arcs)

    # The outer `for` must iterate over the FSAs. With the clauses the other
    # way round, the first iterable is evaluated using whatever `fsa` was
    # left bound by the loop above (the last FSA), so only that FSA's
    # attribute names would be collected.
    tensor_attr_names = {
        name for fsa in fsas for name, _ in fsa.named_tensor_attr()
    }

    for name in tensor_attr_names:
        values = [getattr(fsa, name) for fsa in fsas]
        if isinstance(values[0], torch.Tensor):
            value = torch.cat(values)
        else:
            assert isinstance(values[0], k2.RaggedTensor)
            value = k2.ragged.cat(values, axis=0)
        setattr(fsa_vec, name, value)

    non_tensor_attr_names = {
        name for fsa in fsas for name, _ in fsa.named_non_tensor_attr()
    }

    for name in non_tensor_attr_names:
        if name == 'properties':
            continue

        for fsa in fsas:
            value = getattr(fsa, name, None)
            if value is None:
                continue
            if hasattr(fsa_vec, name):
                # Every FSA carrying this attribute must agree on its value.
                assert getattr(fsa_vec, name) == value
            else:
                setattr(fsa_vec, name, value)

    return fsa_vec
def test_compose(self):
    # Compose two small differentiable FSAs, then check the total score
    # and the gradients that flow back to the arc scores of each input.
    a_text = '''
        0 1 11 1 1.0
        0 2 12 2 2.5
        1 3 -1 -1 0
        2 3 -1 -1 2.5
        3
    '''
    b_text = '''
        0 1 1 1 1.0
        0 2 2 3 3.0
        1 2 3 2 2.5
        2 3 -1 -1 2.0
        3
    '''
    a_fsa = k2.Fsa.from_str(a_text).requires_grad_(True)
    b_fsa = k2.Fsa.from_str(b_text).requires_grad_(True)

    composed = k2.connect(k2.compose(a_fsa, b_fsa, inner_labels='inner'))

    # Convert a single FSA to a FsaVec by replacing its arcs directly in
    # the instance dict. It will retain `requires_grad_` of the FSA.
    composed.__dict__['arcs'] = _k2.create_fsa_vec([composed.arcs])

    scores = k2.get_tot_scores(composed,
                               log_semiring=True,
                               use_double_scores=False)

    # The reference values for `scores`, `a_fsa.grad` and `b_fsa.grad`
    # are computed using GTN.
    # See https://bit.ly/3heLAJq
    assert scores.item() == 10

    scores.backward()

    # Both inputs are expected to receive the same gradient pattern.
    expected_grad = torch.tensor([0., 1., 0., 1.])
    assert torch.allclose(a_fsa.grad, expected_grad)
    assert torch.allclose(b_fsa.grad, expected_grad)
    print(composed)