def backward(
        ctx, unused: torch.Tensor
) -> Tuple[None, None, None, torch.Tensor]:  # noqa
    """Compute the gradient w.r.t. the arc scores saved by `forward`.

    Four values are returned because `forward` accepts four arguments
    (excluding ctx): fsas, log_semiring, use_float_scores, unused_scores.
    Only the last one (the scores tensor) receives a gradient.

    NOTE(review): the incoming gradient is named `unused` and is indeed
    never read — presumably the matching `forward` guarantees this is
    acceptable; confirm against the enclosing Function class.
    """
    fsas = ctx.fsas
    log_semiring = ctx.log_semiring
    use_float_scores = ctx.use_float_scores
    scores, = ctx.saved_tensors

    if log_semiring is False:
        # Tropical semiring: the gradient is 1 on every arc of the best
        # path and 0 elsewhere.
        entering_arcs = fsas.update_entering_arcs(use_float_scores)
        _, ragged_int = _k2.shortest_path(fsas.arcs, entering_arcs)
        best_arc_indexes = ragged_int.values().to(torch.int64)
        grad = torch.zeros_like(scores, requires_grad=False)
        grad[best_arc_indexes] = 1
        return None, None, None, grad

    # Log semiring: combine forward and backward scores into per-arc
    # scores and exponentiate (presumably yielding arc posteriors —
    # the exact semantics live in the _k2 C++ bindings).
    forward_scores = fsas.update_forward_scores_log(use_float_scores)
    backward_scores = fsas.update_backward_scores_log(use_float_scores)
    score_func = (_k2._get_arc_scores_float
                  if use_float_scores else _k2._get_arc_scores_double)
    arc_scores = score_func(fsas=fsas.arcs,
                            forward_scores=forward_scores,
                            backward_scores=backward_scores)
    return None, None, None, arc_scores.exp()
def shortest_path(fsa: Fsa, use_double_scores: bool) -> Fsa:
    '''Return the shortest paths as linear FSAs from the start state
    to the final state in the tropical semiring.

    Note:
      It uses the opposite sign. That is, it uses `max` instead of `min`.

    Args:
      fsa:
        The input FSA. It can be either a single FSA or an FsaVec.
      use_double_scores:
        False to use float, i.e., single precision floating point,
        for scores. True to use double.

    Returns:
      FsaVec, it contains the best paths as linear FSAs
    '''
    entering_arcs = fsa.get_entering_arcs(use_double_scores)
    ragged_arc, ragged_int = _k2.shortest_path(fsa.arcs, entering_arcs)

    best_paths = Fsa(ragged_arc)
    arc_map = ragged_int.values()

    # Tensor attributes follow the surviving arcs via arc_map;
    # non-tensor attributes are copied over unchanged.
    for name, value in fsa.named_tensor_attr():
        setattr(best_paths, name, index_attr(value, arc_map))
    for name, value in fsa.named_non_tensor_attr():
        setattr(best_paths, name, value)

    return best_paths
def shortest_path(fsa: Fsa, use_float_scores: bool) -> Fsa:
    '''Return the shortest paths as linear FSAs from the start state
    to the final state in the tropical semiring.

    Note:
      It uses the opposite sign. That is, it uses `max` instead of `min`.

    Args:
      fsa:
        The input FSA. It can be either a single FSA or an FsaVec.
      use_float_scores:
        True to use float, i.e., single precision floating point,
        for scores. False to use double.

    Returns:
      FsaVec, it contains the best paths as linear FSAs
    '''
    entering_arcs = fsa.update_entering_arcs(use_float_scores)
    ragged_arc, ragged_int = _k2.shortest_path(fsa.arcs, entering_arcs)
    ans = Fsa.from_ragged_arc(ragged_arc)

    # index_select requires int64 indexes, so convert the arc map first.
    arc_map = ragged_int.values().to(torch.int64)
    for name, value in fsa.named_tensor_attr():
        setattr(ans, name, value.index_select(0, arc_map))
    for name, value in fsa.named_non_tensor_attr():
        setattr(ans, name, value)

    # Cached properties of the input do not describe the extracted
    # paths, so drop them if present.
    if hasattr(ans, 'properties'):
        del ans.properties

    return ans
def backward(
        ctx, tot_scores_grad: torch.Tensor
) -> Tuple[None, None, None, torch.Tensor]:  # noqa
    """Compute the gradient of tot_scores w.r.t. `fsas.scores`.

    Caution:
      This uses a slightly indirect approach. Since tot_scores are just
      specific elements of the forward scores, the obvious route would
      be to set gradients w.r.t. the forward scores and backprop with
      BackpropGetForwardScores(). That might be a little slower than
      what we actually do: compute the backward scores, combine them
      with the forward scores into arc posteriors, and let each arc's
      derivative be (posterior in FSA * loss deriv w.r.t. that FSA's
      tot_prob). The result is the same and the underlying C++ code is
      simpler. (BackpropGetForwardScores() exists for harder objectives
      that depend on individual arc posteriors.)
    """
    fsas = ctx.fsas
    log_semiring = ctx.log_semiring
    use_double_scores = ctx.use_double_scores
    # Unpacked but never read below; accessing saved_tensors also lets
    # autograd validate the saved tensor was not modified in place.
    scores, = ctx.saved_tensors

    if log_semiring is False:
        # Tropical semiring: only arcs on the best path get gradient.
        entering_arcs = fsas._get_entering_arcs(use_double_scores)
        _, ragged_int = _k2.shortest_path(fsas.arcs, entering_arcs)
        backprop = (_k2.get_tot_scores_double_tropical_backward
                    if use_double_scores else
                    _k2.get_tot_scores_float_tropical_backward)
        scores_grad = backprop(fsas.arcs, ragged_int, tot_scores_grad)
        # Four return values, one per `forward` argument (excluding
        # ctx): fsas, log_semiring, use_double_scores, unused_scores.
        return None, None, None, scores_grad

    # Log semiring: gradients flow through the arc posteriors.
    arc_post = fsas._get_arc_post(use_double_scores, log_semiring)
    backprop = (_k2.get_tot_scores_double_log_backward
                if use_double_scores else
                _k2.get_tot_scores_float_log_backward)
    scores_grad = backprop(fsas.arcs, arc_post, tot_scores_grad)
    return None, None, None, scores_grad
def shortest_path(fsa: Fsa, use_double_scores: bool) -> Fsa:
    '''Return the shortest paths as linear FSAs from the start state
    to the final state in the tropical semiring.

    Note:
      It uses the opposite sign. That is, it uses `max` instead of `min`.

    Args:
      fsa:
        The input FSA. It can be either a single FSA or an FsaVec.
      use_double_scores:
        False to use float, i.e., single precision floating point,
        for scores. True to use double.

    Returns:
      FsaVec, it contains the best paths as linear FSAs
    '''
    entering_arcs = fsa._get_entering_arcs(use_double_scores)
    ragged_arc, ragged_int = _k2.shortest_path(fsa.arcs, entering_arcs)
    # The helper propagates attributes from `fsa` onto the output
    # according to the arc map.
    return k2.utils.fsa_from_unary_function_tensor(fsa, ragged_arc,
                                                   ragged_int.values())
def backward(
        ctx, tot_scores_grad: torch.Tensor
) -> Tuple[None, None, None, torch.Tensor]:  # noqa
    """Compute the gradient of tot_scores w.r.t. the arc scores.

    Four values are returned because `forward` accepts four arguments
    (excluding ctx): fsas, log_semiring, use_float_scores,
    unused_scores. Only the scores tensor receives a gradient.
    """
    fsas = ctx.fsas
    log_semiring = ctx.log_semiring
    use_float_scores = ctx.use_float_scores
    # Unpacked but never read below; accessing saved_tensors also lets
    # autograd validate the saved tensor was not modified in place.
    scores, = ctx.saved_tensors

    if log_semiring is False:
        # Tropical semiring: only arcs on the best path get gradient.
        entering_arcs = fsas.get_entering_arcs(use_float_scores)
        _, ragged_int = _k2.shortest_path(fsas.arcs, entering_arcs)
        backprop = (_k2._get_tot_scores_float_tropical_backward
                    if use_float_scores else
                    _k2._get_tot_scores_double_tropical_backward)
        grad = backprop(fsas.arcs, ragged_int, tot_scores_grad)
        return None, None, None, grad

    # Log semiring: build per-arc scores from forward/backward scores,
    # then backprop the total-score gradient through them.
    forward_scores = fsas.get_forward_scores_log(use_float_scores)
    backward_scores = fsas.get_backward_scores_log(use_float_scores)
    if use_float_scores:
        arc_score_func = _k2._get_arc_scores_float
        backprop = _k2._get_tot_scores_float_log_backward
    else:
        arc_score_func = _k2._get_arc_scores_double
        backprop = _k2._get_tot_scores_double_log_backward
    arc_scores = arc_score_func(fsas=fsas.arcs,
                                forward_scores=forward_scores,
                                backward_scores=backward_scores)
    grad = backprop(fsas.arcs, arc_scores, tot_scores_grad)
    return None, None, None, grad