def func():
    """Forward-mode JVP of `m @ m^T` with respect to `m`.

    `m`, `tangent`, `matmul` and `forwardprop` are not defined in this view.
    Presumably they are closure/module names supplied by the enclosing scope,
    with `matmul` likely being `math_ops.matmul`. TODO confirm.

    Returns:
      A 2-tuple `(product, product_jvp)`:
        - `product` is `matmul(m, m, transpose_b=True)`.
        - `product_jvp` is the accumulator's JVP for `product`, given `m` was
          watched with `tangent`.
    """
    with forwardprop.ForwardGradientAccumulator() as accumulator:
        # Register `m` as a primal with direction `tangent`, so ops run inside
        # this context propagate a tangent alongside their output.
        accumulator.watch(m, tangent)
        product = matmul(m, m, transpose_b=True)
    # The JVP is looked up after the recording context has closed.
    product_jvp = accumulator.jvp(product)
    return product, product_jvp
def _compute_forwardgrad(primal):
    """Return the JVP of `f` at `primal` along a tangent of constant 1.

    `f`, `constant_op` and `forwardprop` are not defined in this view.
    Presumably `f` is a closure variable from an enclosing helper. TODO confirm.

    Args:
      primal: The input at which `f` is evaluated and watched.
        NOTE(review): the tangent is the scalar constant `1.`, so this
        presumably expects a scalar primal of a compatible dtype. Confirm
        against callers.

    Returns:
      The accumulator's JVP for `f(primal)`.
    """
    unit_tangent = constant_op.constant(1.)
    with forwardprop.ForwardGradientAccumulator() as accumulator:
        accumulator.watch(primal, unit_tangent)
        output = f(primal)
    return accumulator.jvp(output)
def compiled_function(x, tangent):
    """Compute `x @ x^T` and its forward-mode JVP with respect to `x`.

    `forwardprop` and `math_ops` are not defined in this view; presumably they
    are module-level imports. TODO confirm.

    Args:
      x: The matrix operand. It is watched as the primal and used as both
        `matmul` inputs.
      tangent: The direction paired with `x` when it is watched.

    Returns:
      A 2-tuple `(product, product_jvp)`:
        - `product` is `math_ops.matmul(x, x, transpose_b=True)`.
        - `product_jvp` is the accumulator's JVP for `product`.
    """
    with forwardprop.ForwardGradientAccumulator() as accumulator:
        accumulator.watch(x, tangent)
        product = math_ops.matmul(x, x, transpose_b=True)
    # The JVP is looked up after the recording context has closed.
    product_jvp = accumulator.jvp(product)
    return product, product_jvp