def cat(srcs: List[Fsa]) -> Fsa:
    '''Concatenate a list of FsaVec into a single FsaVec.

    CAUTION:
      Only common tensor attributes are kept in the output FsaVec.
      For non-tensor attributes, only one copy is kept in the output
      FsaVec. We choose the first copy of the FsaVec that has the
      lowest index in `srcs`.

    Args:
      srcs:
        A non-empty list of FsaVec. Each element MUST be an FsaVec
        (i.e., `len(src.shape) == 3`).

    Returns:
      Return a single FsaVec concatenated from the input FsaVecs.
    '''
    # Fail early with a clear message. Otherwise an empty list only fails
    # later, inside `_k2.cat` or in `set.intersection()` (which raises
    # TypeError when called with no sets).
    assert len(srcs) > 0, 'Expect at least one FsaVec in `srcs`'
    for src in srcs:
        assert len(src.shape) == 3, f'Expect an FsaVec. Given: {src.shape}'

    # Concatenate the arcs of all inputs along axis 0 (the FSA axis).
    src_ragged_arcs = [fsa.arcs for fsa in srcs]
    ans_ragged_arcs = _k2.cat(src_ragged_arcs, axis=0)
    out_fsa = Fsa(ans_ragged_arcs)

    # Names of the tensor attributes that are present on *every* input.
    common_tensor_attributes = set.intersection(
        *[set(name for name, _ in src.named_tensor_attr()) for src in srcs])

    for name in common_tensor_attributes:
        # We assume that the type of the attributes among
        # FsaVecs are the same if they share the same name.
        values = [getattr(src, name) for src in srcs]
        if isinstance(values[0], torch.Tensor):
            # NOTE: We assume the shape of elements in values
            # differ only in shape[0].
            value = torch.cat(values)
        else:
            assert isinstance(values[0], k2.RaggedInt)
            value = _k2.cat(values, axis=0)
        setattr(out_fsa, name, value)

    # Non-tensor attributes are not concatenated. The first input (lowest
    # index) that defines a given name wins; later copies are ignored.
    for src in srcs:
        for name, value in src.named_non_tensor_attr():
            if not hasattr(out_fsa, name):
                setattr(out_fsa, name, value)

    return out_fsa
def cat(srcs: List[_k2.RaggedInt], axis=0) -> _k2.RaggedInt:
    '''Join several :class:`_k2.RaggedInt` objects into one along `axis`.

    Args:
      srcs:
        The ragged tensors to be concatenated.
      axis:
        The axis to concatenate along. Only 0 and 1 are accepted.

    Returns:
      The single ragged tensor produced by `_k2.cat`.
    '''
    # Only axes 0 and 1 are supported by this wrapper.
    assert axis == 0 or axis == 1
    joined = _k2.cat(srcs, axis)
    return joined