示例#1
0
文件: ops.py 项目: entn-at/k2
def cat(srcs: List[Fsa]) -> Fsa:
    '''Concatenate a list of FsaVec into a single FsaVec.

    CAUTION:
      Only common tensor attributes are kept in the output FsaVec.
      For non-tensor attributes, only one copy is kept in the output
      FsaVec. We choose the first copy of the FsaVec that has the
      lowest index in `srcs`.

    Args:
      srcs:
        A non-empty list of FsaVec. Each element MUST be an FsaVec,
        i.e., ``len(src.shape) == 3``.
    Returns:
      Return a single FsaVec concatenated from the input FsaVecs.
      Common tensor attributes are attached in the order in which they
      are listed by ``srcs[0].named_tensor_attr()``, so the result does
      not depend on string-hash randomization.
    Raises:
      ValueError: if ``srcs`` is empty.
    '''
    if not srcs:
        # Fail early with a clear message; otherwise an empty list would
        # only fail later with a less helpful error (e.g. from
        # ``set.intersection()`` being called with no arguments).
        raise ValueError('srcs must contain at least one FsaVec')

    for src in srcs:
        assert len(src.shape) == 3, f'Expect an FsaVec. Given: {src.shape}'

    ans_ragged_arcs = _k2.cat([fsa.arcs for fsa in srcs], axis=0)
    out_fsa = Fsa(ans_ragged_arcs)

    # Names of tensor attributes present in every source. Walk them in the
    # order they appear in srcs[0] instead of iterating a set: set order for
    # str keys varies from run to run, which would make the attribute order
    # of the output FsaVec non-deterministic.
    first_names = [name for name, _ in srcs[0].named_tensor_attr()]
    rest_names = [{name for name, _ in src.named_tensor_attr()}
                  for src in srcs[1:]]
    common_tensor_attributes = [
        name for name in first_names
        if all(name in names for names in rest_names)
    ]

    for name in common_tensor_attributes:
        # We assume that the type of the attributes among
        # FsaVecs are the same if they share the same name.
        values = [getattr(src, name) for src in srcs]
        if isinstance(values[0], torch.Tensor):
            # NOTE: We assume the shape of elements in values
            # differ only in shape[0].
            value = torch.cat(values)
        else:
            assert isinstance(values[0], k2.RaggedInt)
            value = _k2.cat(values, axis=0)
        setattr(out_fsa, name, value)

    # Non-tensor attributes: keep the first copy seen while scanning srcs
    # in order; later copies with the same name are ignored.
    for src in srcs:
        for name, value in src.named_non_tensor_attr():
            if not hasattr(out_fsa, name):
                setattr(out_fsa, name, value)

    return out_fsa
示例#2
0
文件: ops.py 项目: xjohnxjohn/k2
def cat(srcs: List[_k2.RaggedInt], axis=0) -> _k2.RaggedInt:
    '''Join several :class:`_k2.RaggedInt` objects along ``axis``.

    This is a thin wrapper that checks ``axis`` and then delegates to
    ``_k2.cat``.

    Args:
      srcs:
        The ragged tensors to concatenate, in order.
      axis:
        The axis to concatenate along. Only 0 and 1 are accepted; any
        other value trips an ``assert`` (which is skipped under ``-O``).
    Returns:
      The ragged tensor returned by ``_k2.cat(srcs, axis)``.
    '''
    allowed_axes = (0, 1)
    assert axis in allowed_axes
    joined = _k2.cat(srcs, axis)
    return joined