예제 #1
0
파일: fsa_algo.py 프로젝트: entn-at/k2
def intersect_device(
    a_fsas: Fsa,
    b_fsas: Fsa,
    b_to_a_map: torch.Tensor,
    sorted_match_a: bool = False,
    ret_arc_maps: bool = False
) -> Union[Fsa, Tuple[Fsa, torch.Tensor, torch.Tensor]]:  # noqa
    '''Intersect two FsaVecs, treating epsilons as ordinary symbols.

    The function works on both CPU and GPU, but it is very slow on CPU;
    the `_device` suffix signals that it is intended for GPU use.
    :func:`k2.intersect` is the more general interface (it calls the same
    underlying code, IntersectDevice(), if the inputs are on GPU and
    `a_fsas` is arc-sorted).

    Caution:
      Epsilons are NOT special here: they are matched like any other
      symbol.

    Hint:
      Neither input has to be arc-sorted.

    See :func:`k2.intersect` for the rules used to assign the attributes
    of the output FsaVec.

    Args:
      a_fsas:
        An FsaVec (must have 3 axes, i.e., `len(a_fsas.shape) == 3`).
      b_fsas:
        An FsaVec (must have 3 axes) on the same device as `a_fsas`.
      b_to_a_map:
        A 1-D torch.Tensor with dtype torch.int32 on the same device as
        `a_fsas`. Entry `i` is the FSA-id in `a_fsas` that the i-th FSA
        of `b_fsas` is intersected with. It may be an identity map,
        all zeros, or anything else the caller chooses.

        Requires
            - `b_to_a_map.shape[0] == b_fsas.shape[0]`
            - `0 <= b_to_a_map[i] < a_fsas.shape[0]`
      sorted_match_a:
        Passed through unchanged to the underlying
        `_k2.intersect_device` call.
        NOTE(review): presumably selects a matching strategy that relies
        on `a_fsas` being arc-sorted -- confirm against the binding.
      ret_arc_maps:
        If False, only the resulting Fsa is returned. If True, a tuple
        with three entries is returned:

            - the resulting Fsa

            - a_arc_map, a 1-D torch.Tensor with dtype torch.int32.
              a_arc_map[i] is the index of the arc in a_fsas that the
              i-th arc of the result came from, or -1 if there is no
              corresponding arc in a_fsas.

            - b_arc_map, a 1-D torch.Tensor with dtype torch.int32.
              b_arc_map[i] is the index of the arc in b_fsas that the
              i-th arc of the result came from, or -1 if there is no
              corresponding arc in b_fsas.

    Returns:
      The intersected FsaVec, which satisfies
      `ans.shape == b_fsas.shape`, when `ret_arc_maps` is False.
      Otherwise the tuple `(ans, a_arc_map, b_arc_map)`.
    '''
    # The arc maps are always requested (the `True` flag below) because
    # attribute propagation needs them even when the caller does not.
    arcs, arc_map_a, arc_map_b = _k2.intersect_device(
        a_fsas.arcs,
        a_fsas.properties,
        b_fsas.arcs,
        b_fsas.properties,
        b_to_a_map,
        True,
        sorted_match_a,
    )
    result = Fsa(arcs)

    # Carry the attributes of both inputs over to the result.
    _propagate_aux_labels_binary_function(a_fsas, b_fsas, arc_map_a,
                                          arc_map_b, result)

    return (result, arc_map_a, arc_map_b) if ret_arc_maps else result
예제 #2
0
파일: fsa_algo.py 프로젝트: yyht/k2
def intersect_device(a_fsas: Fsa, b_fsas: Fsa,
                     b_to_a_map: torch.Tensor) -> Fsa:
    '''Compute the intersection of two FSAs treating epsilons
    as real, normal symbols.

    This function supports both CPU and GPU. But it is very slow on CPU.
    That's why this function name ends with `_device`. It is intended for GPU.
    See :func:`k2.intersect` for intersecting two FSAs on CPU.

    Caution:
      Epsilons are treated as real, normal symbols.

    Hint:
      The two inputs do not need to be arc-sorted.

    Attributes of the output are assigned as follows:

      - A tensor attribute present on only one input is copied to the
        output through that input's arc map.
      - A tensor attribute present on both inputs with dtype
        `torch.float32` is combined by adding the two arc-mapped values.
      - A tensor attribute present on both inputs with any other dtype
        is discarded (it does not appear on the output).
      - Non-tensor attributes are copied as-is; on a name clash the value
        from `a_fsas` wins.

    Args:
      a_fsas:
        An FsaVec (must have 3 axes, i.e., `len(a_fsas.shape) == 3`).
      b_fsas:
        An FsaVec (must have 3 axes) on the same device as `a_fsas`.
      b_to_a_map:
        A 1-D torch.Tensor with dtype torch.int32 on the same device
        as `a_fsas`. Map from FSA-id in `b_fsas` to the corresponding
        FSA-id in `a_fsas` that we want to compose it with.
        E.g. might be an identity map, or all-to-zero, or something the
        user chooses.

        Requires
            - `b_to_a_map.shape[0] == b_fsas.shape[0]`
            - `0 <= b_to_a_map[i] < a_fsas.shape[0]`

    Returns:
      Returns composed FsaVec; will satisfy `ans.shape == b_fsas.shape`.
    '''
    # The arc maps are needed below to carry attributes over to the output.
    need_arc_map = True
    ragged_arc, a_arc_map, b_arc_map = _k2.intersect_device(
        a_fsas.arcs, a_fsas.properties, b_fsas.arcs, b_fsas.properties,
        b_to_a_map, need_arc_map)
    out_fsas = Fsa(ragged_arc)

    for name, a_value in a_fsas.named_tensor_attr():
        if not hasattr(b_fsas, name):
            # Only a_fsas has this attribute, copy it via its arc map.
            setattr(out_fsas, name, index(a_value, a_arc_map))
            continue

        # Both a_fsas and b_fsas have this attribute.
        # We only support attributes with dtype `torch.float32`.
        # Other kinds of attributes are discarded.
        if a_value.dtype != torch.float32:
            continue
        b_value = getattr(b_fsas, name)
        assert b_value.dtype == torch.float32, \
            f'attribute {name!r} must be torch.float32 in b_fsas'

        value = index_select(a_value, a_arc_map) \
            + index_select(b_value, b_arc_map)
        setattr(out_fsas, name, value)

    # Now copy tensor attributes that are in b_fsas but are not in a_fsas.
    # Checking a_fsas explicitly matters: a shared attribute that was
    # discarded above is absent from out_fsas, and testing only out_fsas
    # would wrongly resurrect it from b_fsas.
    for name, b_value in b_fsas.named_tensor_attr():
        if hasattr(a_fsas, name) or hasattr(out_fsas, name):
            continue
        setattr(out_fsas, name, index(b_value, b_arc_map))

    # Non-tensor attributes: a_fsas takes precedence on name clashes.
    for name, a_value in a_fsas.named_non_tensor_attr():
        setattr(out_fsas, name, a_value)

    for name, b_value in b_fsas.named_non_tensor_attr():
        if not hasattr(out_fsas, name):
            setattr(out_fsas, name, b_value)

    return out_fsas