Пример #1
0
    def forward(ctx, fsas: Fsa, out_fsa: List[Fsa],
                unused_fsas_scores: torch.Tensor) -> torch.Tensor:
        '''Compute the union of all fsas in a FsaVec.

        Args:
          ctx:
            The autograd context. The arc map of the union is stored on it
            as `ctx.arc_map`, and `unused_fsas_scores` is registered through
            `save_for_backward` (presumably for use by a backward pass that
            is not visible here).
          fsas:
            The input FsaVec. Caution: every fsa in the FsaVec must be
            non-empty (i.e., have at least two states).
          out_fsa:
            A list with exactly one entry. Because this function may only
            return values of type `torch.Tensor`, the union result is handed
            back by writing it into `out_fsa[0]`.
          unused_fsas_scores:
            The same tensor as `fsas.scores`. It is passed in only so that
            autograd can track it; its values are not read here.

        Returns:
          `out_fsa[0].scores`, the scores tensor of the union FSA. The caller
          is expected to discard it.
        '''
        # True: ask _k2.union for the arc map as well as the result arcs.
        ragged_arc, arc_map = _k2.union(fsas.arcs, True)
        out_fsa[0] = Fsa(ragged_arc)
        union_fsa = out_fsa[0]

        # Re-index every tensor attribute except scores with arc_map, so that
        # its entries line up with the arcs of the union FSA.
        tensor_attrs = fsas.named_tensor_attr(include_scores=False)
        for attr_name, attr_value in tensor_attrs:
            setattr(union_fsa, attr_name, k2.index(attr_value, arc_map))

        # Non-tensor attributes are copied over unchanged.
        for attr_name, attr_value in fsas.named_non_tensor_attr():
            setattr(union_fsa, attr_name, attr_value)

        ctx.arc_map = arc_map
        ctx.save_for_backward(unused_fsas_scores)

        # Only returned because autograd Functions must return tensors.
        return union_fsa.scores
Пример #2
0
def union(fsas: Fsa) -> Fsa:
    '''Compute the union of a FsaVec.

    Args:
      fsas:
        A FsaVec. That is, len(fsas.shape) == 3.

    Returns:
      A single Fsa that is the union of the input fsas. Attributes of the
      input fsas are not carried over yet (see the TODO below).
    '''
    # TODO(fangjun): pass True here once arc_map is implemented in _k2.union.
    ragged_arc, _unused_arc_map = _k2.union(fsas.arcs, False)

    # TODO(fangjun): copy attributes from the input fsas to the result.
    return Fsa.from_ragged_arc(ragged_arc)
Пример #3
0
def union(fsas: Fsa) -> Fsa:
    '''Compute the union of a FsaVec.

    Caution:
      Every fsa in `fsas` must be non-empty, i.e., contain at least
      two states.

    Args:
      fsas:
        A FsaVec. That is, len(fsas.shape) == 3.

    Returns:
      A single Fsa that is the union of the input fsas, built by
      `k2.utils.fsa_from_unary_function_tensor` from the union arcs and
      the arc map returned by `_k2.union`.
    '''
    # True: request the arc map so attributes can be propagated.
    union_arcs, union_arc_map = _k2.union(fsas.arcs, True)
    return k2.utils.fsa_from_unary_function_tensor(fsas, union_arcs,
                                                   union_arc_map)