def backward(
        ctx, tot_scores_grad: torch.Tensor
) -> Tuple[None, None, None, torch.Tensor]:  # noqa
    """Back-propagate gradients from total scores to ``fsas.scores``.

    Caution: this backward function uses a slightly indirect approach to
    compute the gradients.  Since the tot_scores are just computed as
    specific elements of `forward_scores`, the obvious way to get
    derivatives w.r.t. fsas.scores would be to set gradients w.r.t. the
    forward scores and then use BackpropGetForwardScores() to do the
    backprop.  But that might be a little slower than what we actually
    do.  What we actually do is to compute the backward scores and use
    them and the forward scores to compute the posteriors, and let the
    derivs be the (posterior in FSA * loss_deriv w.r.t. that FSA's
    tot_prob).  The result is the same, and the underlying C++ code is
    simpler.  (BackpropGetForwardScores() was added in order to compute
    slightly more difficult objective functions, that depend on the
    individual arc posteriors).
    """
    fsas = ctx.fsas
    log_semiring = ctx.log_semiring
    use_double_scores = ctx.use_double_scores
    # The saved scores are not needed by this gradient formula; we still
    # access ctx.saved_tensors so autograd can raise its usual error if
    # backward runs after the saved buffers were freed.  Named to match
    # the `unused_scores` argument of forward().
    unused_scores, = ctx.saved_tensors

    if log_semiring is False:
        # Tropical semiring: the tot_score is realized by the single
        # best path, so the gradient is tot_scores_grad scattered onto
        # the arcs of that path (found via shortest_path).
        entering_arcs = fsas._get_entering_arcs(use_double_scores)
        _, ragged_int = _k2.shortest_path(fsas.arcs, entering_arcs)
        if use_double_scores:
            scores_grad = _k2.get_tot_scores_double_tropical_backward(
                fsas.arcs, ragged_int, tot_scores_grad)
        else:
            scores_grad = _k2.get_tot_scores_float_tropical_backward(
                fsas.arcs, ragged_int, tot_scores_grad)
        # We return four values since the `forward` method accepts four
        # arguments (excluding ctx):
        #   fsas, log_semiring, use_double_scores, unused_scores
        return None, None, None, scores_grad
    else:
        # Log semiring: d(tot_score)/d(arc_score) is the arc posterior,
        # so the gradient is the posteriors scaled by tot_scores_grad.
        arc_post = fsas._get_arc_post(use_double_scores, log_semiring)
        if use_double_scores:
            bprop_func = _k2.get_tot_scores_double_log_backward
        else:
            bprop_func = _k2.get_tot_scores_float_log_backward

        scores_grad = bprop_func(fsas.arcs, arc_post, tot_scores_grad)
        return None, None, None, scores_grad
def backward(
        ctx, tot_scores_grad: torch.Tensor
) -> Tuple[None, None, None, torch.Tensor]:  # noqa
    """Back-propagate gradients from total scores to ``fsas.scores``.

    For the tropical semiring the tot_score of each FSA is attained by
    its best path, so the incoming gradient is scattered onto the arcs
    of that path.  For the log semiring the derivative of the tot_score
    w.r.t. each arc score is that arc's posterior, computed here from
    the forward and backward scores; the gradient is the posteriors
    scaled by ``tot_scores_grad``.
    """
    fsas = ctx.fsas
    log_semiring = ctx.log_semiring
    use_double_scores = ctx.use_double_scores
    # The saved scores are not needed by this gradient formula; we still
    # access ctx.saved_tensors so autograd can raise its usual error if
    # backward runs after the saved buffers were freed.  Named to match
    # the `unused_scores` argument of forward().
    unused_scores, = ctx.saved_tensors

    if log_semiring is False:
        # Tropical semiring: find the best path and scatter the
        # gradient onto its arcs.
        entering_arcs = fsas.get_entering_arcs(use_double_scores)
        _, ragged_int = _k2.shortest_path(fsas.arcs, entering_arcs)
        if use_double_scores:
            out_grad = _k2.get_tot_scores_double_tropical_backward(
                fsas.arcs, ragged_int, tot_scores_grad)
        else:
            out_grad = _k2.get_tot_scores_float_tropical_backward(
                fsas.arcs, ragged_int, tot_scores_grad)
        # We return four values since the `forward` method accepts four
        # arguments (excluding ctx):
        #   fsas, log_semiring, use_double_scores, unused_scores
        return None, None, None, out_grad
    else:
        # Log semiring: combine forward and backward scores into
        # per-arc posteriors, then scale by the incoming gradient.
        forward_scores = fsas.get_forward_scores_log(use_double_scores)
        backward_scores = fsas.get_backward_scores_log(use_double_scores)
        if use_double_scores:
            func = _k2.get_arc_scores_double
            bprop_func = _k2.get_tot_scores_double_log_backward
        else:
            func = _k2.get_arc_scores_float
            bprop_func = _k2.get_tot_scores_float_log_backward

        arc_scores = func(fsas=fsas.arcs,
                          forward_scores=forward_scores,
                          backward_scores=backward_scores)
        out_grad = bprop_func(fsas.arcs, arc_scores, tot_scores_grad)
        return None, None, None, out_grad