示例#1
0
文件: fsa_algo.py 项目: pkufool/k2
def expand_ragged_attributes(
        fsas: Fsa,
        ret_arc_map: bool = False
) -> Union[Fsa, Tuple[Fsa, torch.Tensor]]:  # noqa
    '''
    Convert every ragged (k2.RaggedInt) attribute of `fsas` into a linear
    (Tensor) attribute, expanding single arcs into sequences of arcs where
    that is needed.  Supports autograd.

    If `fsas` has no ragged attributes, `fsas` itself is returned (not a
    copy).

         fsas:   The source Fsa.
         ret_arc_map:  if true, return a pair (new_fsas, arc_map) where
              `arc_map` is an int32 tensor mapping arcs of the result to
              arcs of `fsas`, with -1's for newly created arcs.  When
              `fsas` has no ragged attributes the map is simply
              0, 1, ..., fsas.num_arcs - 1.
              If false, only new_fsas is returned.
    '''
    # Gather the ragged attributes once, as (name, tensor) pairs, keeping
    # the order in which `fsas` reports them.
    ragged_items = [
        (attr_name, attr_value)
        for attr_name, attr_value in fsas.named_tensor_attr(
            include_scores=False)
        if isinstance(attr_value, k2.RaggedInt)
    ]

    # Nothing to expand: hand back the input unchanged.
    if not ragged_items:
        if not ret_arc_map:
            return fsas
        identity_map = torch.arange(fsas.num_arcs,
                                    dtype=torch.int32,
                                    device=fsas.device)
        return (fsas, identity_map)

    ragged_names = [attr_name for attr_name, _ in ragged_items]
    ragged_tensors = [attr_value for _, attr_value in ragged_items]

    expanded = _k2.expand_arcs(fsas.arcs, ragged_tensors)
    new_arcs, linear_labels, arc_map = expanded

    # From here on this mirrors `fsa_from_unary_function_tensor()`, with
    # modifications.
    result = Fsa(new_arcs)

    # Attributes that were already linear are carried over through arc_map.
    for attr_name, attr_value in fsas.named_tensor_attr(
            include_scores=False):
        if isinstance(attr_value, k2.RaggedInt):
            continue
        setattr(result, attr_name, k2.index(attr_value, arc_map))

    # Formerly ragged attributes: attach their linearized replacements.
    for attr_name, linear_value in zip(ragged_names, linear_labels):
        setattr(result, attr_name, linear_value)

    # Non-tensor attributes are copied across as they are.
    for attr_name, attr_value in fsas.named_non_tensor_attr():
        setattr(result, attr_name, attr_value)

    # Hook the scores up so that autograd works on them.
    k2.autograd_utils.phantom_index_select_scores(result, fsas.scores,
                                                  arc_map)

    return (result, arc_map) if ret_arc_map else result
示例#2
0
文件: fsa_algo.py 项目: entn-at/k2
def expand_ragged_attributes(
    fsas: Fsa,
    ret_arc_map: bool = False,
    ragged_attribute_names: Optional[List[str]] = None
) -> Union[Fsa, Tuple[Fsa, torch.Tensor]]:  # noqa
    '''
    Convert ragged (k2.RaggedInt) attributes of `fsas` into linear (Tensor)
    attributes, expanding single arcs into sequences of arcs where that is
    needed.  Supports autograd.  If there are no ragged attributes to
    expand, `fsas` itself is returned (not a copy).

    Caution: this function will ensure that for final-arcs in the returned
    fsa, the corresponding labels for all ragged attributes are -1; it will
    add an extra arc at the end if necessary to ensure this, if the
    original ragged attributes did not have -1 as their final element on
    final-arcs.  (Note: our intention is that -1's on final arcs, like
    filler symbols, are removed when making attributes ragged; this is what
    fsa_from_unary_function_ragged() does if remove_filler==True, which is
    the default.)

         fsas:   The source Fsa.
         ret_arc_map:  if true, return a pair (new_fsas, arc_map) where
              `arc_map` is an int32 tensor mapping arcs of the result to
              arcs of `fsas`, with -1's for newly created arcs.
              If false, only new_fsas is returned.
         ragged_attribute_names:  If specified, only the ragged attributes
              named in this list are expanded to linear tensor attributes;
              the rest stay ragged.  Every named attribute must be a
              k2.RaggedInt (checked with an assertion).  If None, all
              ragged attributes of `fsas` are expanded.
    '''
    if ragged_attribute_names is not None:
        # Caller chose the attributes; each one must really be ragged.
        names_to_expand = ragged_attribute_names
        tensors_to_expand = [
            getattr(fsas, attr_name) for attr_name in names_to_expand
        ]
        assert all(
            isinstance(tensor, k2.RaggedInt) for tensor in tensors_to_expand)
    else:
        # Default: pick up every ragged attribute, in the order reported
        # by `fsas`.
        ragged_items = [
            (attr_name, attr_value)
            for attr_name, attr_value in fsas.named_tensor_attr(
                include_scores=False)
            if isinstance(attr_value, k2.RaggedInt)
        ]
        names_to_expand = [attr_name for attr_name, _ in ragged_items]
        tensors_to_expand = [attr_value for _, attr_value in ragged_items]

    # Nothing to expand: hand back the input unchanged.
    if not tensors_to_expand:
        if not ret_arc_map:
            return fsas
        identity_map = torch.arange(fsas.num_arcs,
                                    dtype=torch.int32,
                                    device=fsas.device)
        return (fsas, identity_map)

    expanded = _k2.expand_arcs(fsas.arcs, tensors_to_expand)
    new_arcs, linear_labels, arc_map = expanded

    # From here on this mirrors `fsa_from_unary_function_tensor()`, with
    # modifications.
    result = Fsa(new_arcs)

    # Carry over everything that is not being linearized: plain tensors
    # (using the attribute's filler as the default value passed to
    # index_select) and ragged attributes that were not selected.
    for attr_name, attr_value in fsas.named_tensor_attr(
            include_scores=False):
        if isinstance(attr_value, torch.Tensor):
            fill_value = float(fsas.get_filler(attr_name))
            mapped = index_select(attr_value, arc_map,
                                  default_value=fill_value)
            setattr(result, attr_name, mapped)
            continue
        if attr_name in names_to_expand:
            # Linearized version is attached below.
            continue
        setattr(result, attr_name, index(attr_value, arc_map))

    # Formerly ragged attributes: attach their linearized replacements.
    for attr_name, linear_value in zip(names_to_expand, linear_labels):
        setattr(result, attr_name, linear_value)

    # Non-tensor attributes are copied across as they are.
    for attr_name, attr_value in fsas.named_non_tensor_attr():
        setattr(result, attr_name, attr_value)

    # Hook the scores up so that autograd works on them.
    k2.autograd_utils.phantom_index_select_scores(result, fsas.scores,
                                                  arc_map)

    # -1's must appear only on final-arcs, never on non-final arcs.
    if hasattr(result, 'aux_labels'):
        _k2.fix_final_labels(result.arcs, result.aux_labels)

    return (result, arc_map) if ret_arc_map else result