示例#1
0
    def backward(
            ctx, unused: torch.Tensor
    ) -> Tuple[None, None, None, torch.Tensor]:  # noqa
        """Return the gradient for the scores argument of `forward`.

        Only the fourth `forward` argument (``unused_scores``) receives a
        gradient; `fsas`, `log_semiring` and `use_float_scores` get ``None``.

        - Tropical semiring (``log_semiring is False``): a tensor shaped like
          the saved scores, holding 1 at the arcs of the best path returned by
          ``_k2.shortest_path`` and 0 elsewhere.
        - Otherwise (log semiring): ``exp`` of the per-arc scores computed from
          the forward and backward scores of `fsas`.

        NOTE(review): the incoming gradient (``unused``) is never read, so the
        result is only correct when that gradient is all ones -- confirm
        against the callers of `forward`.

        Returns:
          A 4-tuple ``(None, None, None, grad)``.
        """
        fsas = ctx.fsas
        log_semiring = ctx.log_semiring
        use_float = ctx.use_float_scores
        scores, = ctx.saved_tensors

        if log_semiring is not False:
            fwd = fsas.update_forward_scores_log(use_float)
            bwd = fsas.update_backward_scores_log(use_float)
            get_arc_scores = (_k2._get_arc_scores_float if use_float
                              else _k2._get_arc_scores_double)
            per_arc = get_arc_scores(fsas=fsas.arcs,
                                     forward_scores=fwd,
                                     backward_scores=bwd)
            return None, None, None, per_arc.exp()

        # Tropical semiring: one-hot mask over the arcs of the best path.
        entering = fsas.update_entering_arcs(use_float)
        _, best_path = _k2.shortest_path(fsas.arcs, entering)
        arc_indexes = best_path.values().to(torch.int64)
        grad = torch.zeros_like(scores, requires_grad=False)
        grad[arc_indexes] = 1
        # One entry per `forward` argument (excluding ctx):
        #      fsas, log_semiring, use_float_scores, unused_scores
        return None, None, None, grad
示例#2
0
文件: fsa_algo.py 项目: OUC-lan/k2
def shortest_path(fsa: Fsa, use_double_scores: bool) -> Fsa:
    '''Return the best paths of `fsa` as linear FSAs.

    The paths go from the start state to the final state and are computed
    in the tropical semiring.

    Note:
      The sign convention is reversed: `max` is used instead of `min`.

    Args:
      fsa:
        The input FSA; either a single FSA or an FsaVec.
      use_double_scores:
        True to compute with double-precision scores; False to use float
        (single precision).

    Returns:
      An FsaVec holding the best paths as linear FSAs. Every tensor
      attribute of `fsa` is indexed (via `index_attr`) by the arc map
      obtained from `_k2.shortest_path`. Non-tensor attributes are copied
      unchanged.
    '''
    entering = fsa.get_entering_arcs(use_double_scores)
    best_arcs, best_arc_ids = _k2.shortest_path(fsa.arcs, entering)
    result = Fsa(best_arcs)

    # The flattened values of the ragged result serve as the arc map used
    # to pick entries out of the input's tensor attributes.
    src_arc_indexes = best_arc_ids.values()
    for attr_name, tensor in fsa.named_tensor_attr():
        setattr(result, attr_name, index_attr(tensor, src_arc_indexes))

    # Non-tensor attributes are assigned by reference, without indexing.
    for attr_name, other in fsa.named_non_tensor_attr():
        setattr(result, attr_name, other)

    return result
示例#3
0
def shortest_path(fsa: Fsa, use_float_scores: bool) -> Fsa:
    '''Return the best paths of `fsa` as linear FSAs.

    The paths go from the start state to the final state and are computed
    in the tropical semiring.

    Note:
      The sign convention is reversed: `max` is used instead of `min`.

    Args:
      fsa:
        The input FSA; either a single FSA or an FsaVec.
      use_float_scores:
        True to compute with float (single-precision) scores; False to use
        double.

    Returns:
      An FsaVec holding the best paths as linear FSAs. Tensor attributes of
      `fsa` are row-selected by the arc map from `_k2.shortest_path`.
      Non-tensor attributes are copied unchanged, except that any
      `properties` attribute is removed from the result.
    '''
    entering = fsa.update_entering_arcs(use_float_scores)
    best_arcs, best_arc_ids = _k2.shortest_path(fsa.arcs, entering)
    result = Fsa.from_ragged_arc(best_arcs)

    # index_select needs an int64 index tensor.
    src_arc_indexes = best_arc_ids.values().to(torch.int64)
    for attr_name, tensor in fsa.named_tensor_attr():
        setattr(result, attr_name, tensor.index_select(0, src_arc_indexes))

    for attr_name, other in fsa.named_non_tensor_attr():
        setattr(result, attr_name, other)

    # Presumably `properties` describes the input FSA and is stale for the
    # linear output, so it is dropped.
    if hasattr(result, 'properties'):
        del result.properties

    return result
示例#4
0
文件: autograd.py 项目: zhu-han/k2
    def backward(
        ctx, tot_scores_grad: torch.Tensor
    ) -> Tuple[None, None, None, torch.Tensor]:  # noqa
        """Backprop `tot_scores_grad` to the scores of `ctx.fsas`.

        Caution -- the log-semiring path is deliberately indirect.  The total
        scores are particular entries of `forward_scores`, so the obvious way
        to differentiate w.r.t. fsas.scores would be to place gradients on the
        forward scores and backprop them with BackpropGetForwardScores().
        That could be slightly slower than what is done here.  Instead, the
        backward scores are computed and combined with the forward scores to
        obtain arc posteriors, and each arc's derivative is taken to be
        (posterior in its FSA * loss_deriv w.r.t. that FSA's tot_prob).  The
        result is identical and the underlying C++ code is simpler.
        (BackpropGetForwardScores() was added for somewhat harder objectives
        that depend on individual arc posteriors.)

        In the tropical semiring (``log_semiring is False``) the gradient is
        instead computed from the best-path arcs returned by
        ``_k2.shortest_path``.

        Returns:
          ``(None, None, None, scores_grad)`` -- one entry per argument of
          `forward` (excluding ctx): fsas, log_semiring, use_double_scores,
          unused_scores.
        """
        fsas = ctx.fsas
        log_semiring = ctx.log_semiring
        use_double = ctx.use_double_scores
        scores, = ctx.saved_tensors

        if log_semiring is False:
            entering = fsas._get_entering_arcs(use_double)
            _, best_path = _k2.shortest_path(fsas.arcs, entering)
            tropical_bprop = (_k2.get_tot_scores_double_tropical_backward
                              if use_double else
                              _k2.get_tot_scores_float_tropical_backward)
            scores_grad = tropical_bprop(fsas.arcs, best_path,
                                         tot_scores_grad)
        else:
            posteriors = fsas._get_arc_post(use_double, log_semiring)
            log_bprop = (_k2.get_tot_scores_double_log_backward
                         if use_double else
                         _k2.get_tot_scores_float_log_backward)
            scores_grad = log_bprop(fsas.arcs, posteriors, tot_scores_grad)

        return None, None, None, scores_grad
示例#5
0
文件: fsa_algo.py 项目: entn-at/k2
def shortest_path(fsa: Fsa, use_double_scores: bool) -> Fsa:
    '''Return the best paths of `fsa` as linear FSAs.

    The paths go from the start state to the final state and are computed
    in the tropical semiring.

    Note:
      The sign convention is reversed: `max` is used instead of `min`.

    Args:
      fsa:
        The input FSA; either a single FSA or an FsaVec.
      use_double_scores:
        True to compute with double-precision scores; False to use float
        (single precision).

    Returns:
      An FsaVec holding the best paths as linear FSAs. It is built by
      ``k2.utils.fsa_from_unary_function_tensor`` from `fsa`, the best-path
      arcs, and the arc map returned by ``_k2.shortest_path``.
    '''
    entering = fsa._get_entering_arcs(use_double_scores)
    best_arcs, best_arc_ids = _k2.shortest_path(fsa.arcs, entering)
    return k2.utils.fsa_from_unary_function_tensor(
        fsa, best_arcs, best_arc_ids.values())
示例#6
0
    def backward(
        ctx, tot_scores_grad: torch.Tensor
    ) -> Tuple[None, None, None, torch.Tensor]:  # noqa
        """Backprop `tot_scores_grad` to the scores of `ctx.fsas`.

        - Tropical semiring (``log_semiring is False``): the best path is
          recomputed with ``_k2.shortest_path``. Its arcs, together with
          `tot_scores_grad`, are passed to the float or double
          ``_get_tot_scores_*_tropical_backward`` kernel.
        - Otherwise (log semiring): per-arc scores are computed from the
          forward and backward scores. They are then passed, together with
          `tot_scores_grad`, to the ``_get_tot_scores_*_log_backward`` kernel.

        Returns:
          ``(None, None, None, grad)`` -- one entry per argument of `forward`
          (excluding ctx): fsas, log_semiring, use_float_scores,
          unused_scores.
        """
        fsas = ctx.fsas
        log_semiring = ctx.log_semiring
        use_float = ctx.use_float_scores
        scores, = ctx.saved_tensors

        if log_semiring is False:
            entering = fsas.get_entering_arcs(use_float)
            _, best_path = _k2.shortest_path(fsas.arcs, entering)
            tropical_bprop = (_k2._get_tot_scores_float_tropical_backward
                              if use_float else
                              _k2._get_tot_scores_double_tropical_backward)
            grad = tropical_bprop(fsas.arcs, best_path, tot_scores_grad)
        else:
            fwd = fsas.get_forward_scores_log(use_float)
            bwd = fsas.get_backward_scores_log(use_float)
            if use_float:
                arc_scores_fn, log_bprop = (
                    _k2._get_arc_scores_float,
                    _k2._get_tot_scores_float_log_backward)
            else:
                arc_scores_fn, log_bprop = (
                    _k2._get_arc_scores_double,
                    _k2._get_tot_scores_double_log_backward)
            per_arc = arc_scores_fn(fsas=fsas.arcs,
                                    forward_scores=fwd,
                                    backward_scores=bwd)
            grad = log_bprop(fsas.arcs, per_arc, tot_scores_grad)

        return None, None, None, grad