def _root_jvp(const_lengths, jaxprs, primals, tangents):
    """Forward-mode rule for the custom root solve, by implicit differentiation.

    The root solve is evaluated on ``primals`` via ``_custom_root``; the tangent
    of the solution is then obtained from the tangents of the parameters of
    ``jaxprs.f`` using the implicit function theorem (derivation below).

    Args:
      const_lengths: passed to ``_split_root_args`` to split the flat argument
        lists, and forwarded to ``_custom_root``.
      jaxprs: container with at least ``f`` (the residual function) and
        ``l_and_s`` (the linearize-and-solve function) jaxprs.
      primals: flat primal inputs.
      tangents: flat input tangents, laid out like ``primals``.

    Returns:
      A pair ``(solution, solution_dot)``. Each holds the root values first,
      followed by the auxiliary outputs; the auxiliary tangents are zeros
      produced by ``lax.zeros_like_array``.
    """
    params, _ = _split_root_args(primals, const_lengths)
    outputs = _custom_root(const_lengths, jaxprs, *primals)

    # The leading len(jaxprs.f.out_avals) outputs are the root itself; whatever
    # follows is auxiliary output.
    num_solution_vals = len(jaxprs.f.out_avals)
    solution, aux = split_list(outputs, [num_solution_vals])

    params_dot, _ = _split_root_args(tangents, const_lengths)

    # F(m, u) = 0      # system of equations in u, parameterized by m
    #                  # solution is u*(m) defined in a neighborhood
    # F(m, u*(m)) = 0  # satisfied in a neighborhood
    #
    # ∂_0 F(m, u*(m)) + ∂_1 F(m, u*(m)) ∂ u*(m) = 0       # implied by line above
    # ∂ u*(m) = - (∂_1 F(m, u*(m)))^{-1} ∂_0 F(m, u*(m))  # rearrange
    #
    # ∂ u*(m)[v] = - (∂_1 F(m, u*(m)))^{-1} [∂_0 F(m, u*(m))[v]]  # jvp

    f = core.jaxpr_as_fun(jaxprs.f)
    linearize_and_solve = partial(
        core.jaxpr_as_fun(jaxprs.l_and_s), *params.l_and_s)

    # f with the solution held fixed, so its JVP is taken w.r.t. the
    # parameters only: this yields ∂_0 F(m, u*(m))[v].
    f_at_solution = lambda *f_params: f(*f_params, *solution)
    jvp_wrt_params = ad.jvp(lu.wrap_init(f_at_solution))
    _, rhs = jvp_wrt_params.call_wrapped(params.f, params_dot.f)

    # Apply (∂_1 F)^{-1} to the right-hand side, then negate.
    solved = linearize_and_solve(*solution, *rhs)
    solution_dot = _map(operator.neg, solved)

    # Append aux to the outputs, with zero tangents for the aux values.
    solution += aux
    solution_dot += _map(lax.zeros_like_array, aux)

    return solution, solution_dot
def _custom_ivjp_jvp(primals, tangents, *, fun_jaxpr, ivjp_jaxpr):
    """JVP rule for ``custom_ivjp_p``.

    The primal outputs come from re-binding ``custom_ivjp_p`` with the same
    parameters; the output tangents come from differentiating ``fun_jaxpr``
    directly.

    Args:
      primals: primal inputs.
      tangents: input tangents matching ``primals``.
      fun_jaxpr: jaxpr of the function being differentiated.
      ivjp_jaxpr: jaxpr forwarded unchanged to the ``custom_ivjp_p`` bind.

    Returns:
      A pair ``(primals_out, tangents_out)``.
    """
    primals_out = custom_ivjp_p.bind(
        *primals, fun_jaxpr=fun_jaxpr, ivjp_jaxpr=ivjp_jaxpr)

    # FIXME: This might compute the primals multiple times, but we only need to
    #        do this trick while linearizing. It should be possible to do it
    #        through a custom partial eval rule.
    fun = core.jaxpr_as_fun(fun_jaxpr)
    jvp_of_fun = ad.jvp(lu.wrap_init(fun))
    _recomputed_primals, tangents_out = jvp_of_fun.call_wrapped(
        primals, tangents)

    return primals_out, tangents_out
def _tangent_linear_map(func, params, params_dot, *x):
    """Differentiate a linear map with respect to its parameters only.

    If ``func(*params, *x)`` is linear in ``x`` and evaluates ``A @ x``, the
    result is ``∂A @ x``: the JVP of ``func`` where ``params`` carry the
    tangents ``params_dot`` and every element of ``x`` carries a zero tangent.

    Args:
      func: the function to differentiate.
      params: list of parameter values (it is concatenated with ``list(x)``).
      params_dot: list of parameter tangents; at least one entry must not be
        ``ad_util.zero``.
      *x: the inputs the map is applied to.

    Returns:
      The tangent output of the JVP.
    """
    # An all-zero parameter tangent is rejected up front.
    assert not all(p is ad_util.zero for p in params_dot)

    x_tangents = [ad_util.zero] * len(x)
    jvp_of_func = ad.jvp(lu.wrap_init(func))
    primals_in = params + list(x)
    tangents_in = params_dot + x_tangents
    _primal_out, out_tangent = jvp_of_func.call_wrapped(primals_in, tangents_in)
    return out_tangent
def _root_jvp(primals, tangents, num_consts, jaxpr, solve, tangent_solve):
    """JVP rule for ``root_p``, derived by implicit differentiation.

    Args:
      primals: flat primal inputs; the first ``num_consts`` are the parameters
        of the system of equations.
      tangents: input tangents; the first ``num_consts`` are the parameter
        tangents.
      num_consts: number of leading parameter entries.
      jaxpr: jaxpr of the residual function, called as
        ``f(*params, *solution)``.
      solve: forwarded unchanged to the ``root_p`` bind.
      tangent_solve: called as ``tangent_solve(linearized_f, *rhs)`` to apply
        the inverse of the linearization of ``f`` w.r.t. the solution.

    Returns:
      A pair ``(solution, solution_dot)``: the solution as a tuple and its
      tangent as a list.
    """
    params = primals[:num_consts]
    solution = tuple(
        root_p.bind(
            *primals,
            num_consts=num_consts,
            jaxpr=jaxpr,
            solve=solve,
            tangent_solve=tangent_solve,
        )
    )
    params_dot = tangents[:num_consts]

    # F(m, u) = 0      # system of equations in u, parameterized by m
    #                  # solution is u*(m) defined in a neighborhood
    # F(m, u*(m)) = 0  # satisfied in a neighborhood
    #
    # ∂_0 F(m, u*(m)) + ∂_1 F(m, u*(m)) ∂ u*(m) = 0       # implied by line above
    # ∂ u*(m) = - (∂_1 F(m, u*(m)))^{-1} ∂_0 F(m, u*(m))  # rearrange
    #
    # ∂ u*(m)[v] = - (∂_1 F(m, u*(m)))^{-1} [∂_0 F(m, u*(m))[v]]  # jvp

    f = core.jaxpr_as_fun(jaxpr)

    # Two partial views of f: one varying only the solution, one varying only
    # the parameters.
    f_fixed_params = lambda *sol: f(*(params + sol))
    f_fixed_solution = lambda *prm: f(*(prm + solution))

    # rhs = ∂_0 F(m, u*(m))[v]
    jvp_wrt_params = ad.jvp(lu.wrap_init(f_fixed_solution))
    _, rhs = jvp_wrt_params.call_wrapped(params, params_dot)

    # Linear map u_dot -> ∂_1 F(m, u*(m))[u_dot]
    _, f_jvp_wrt_solution = api.linearize(f_fixed_params, *solution)

    unsigned_solution_dot = tangent_solve(f_jvp_wrt_solution, *rhs)
    solution_dot = [-t for t in unsigned_solution_dot]

    return solution, solution_dot
def _jvp(primals, tangents, **params):
    """Evaluate the JVP of ``self.impl`` at ``primals`` along ``tangents``.

    ``self`` is not a parameter of this function; it is a free variable
    resolved from an enclosing scope that is outside this view.

    Args:
      primals: primal inputs.
      tangents: input tangents matching ``primals``.
      **params: keyword parameters handed to ``lu.wrap_init`` together with
        ``self.impl``.

    Returns:
      Whatever ``ad.jvp(...).call_wrapped(primals, tangents)`` returns.
    """
    wrapped_impl = lu.wrap_init(self.impl, params)
    jvp_of_impl = ad.jvp(wrapped_impl)
    return jvp_of_impl.call_wrapped(primals, tangents)
def jvp_jaxpr(jaxpr):
    """Trace the JVP of ``jaxpr`` into a new jaxpr.

    Args:
      jaxpr: the jaxpr to differentiate; its ``in_dim_binders`` and
        ``in_binders`` are read to build the input avals.

    Returns:
      A pair ``(jaxpr, consts)`` as produced by ``trace_to_jaxpr_dynamic``.
    """
    wrapped = lu.wrap_init(jaxpr_as_fun(jaxpr))

    # Map each dimension binder to its aval so the input avals can have those
    # variables substituted.
    dim_avals = {v: v.aval for v in jaxpr.in_dim_binders}
    in_avals = [
        _replace_vars_with_avals(dim_avals, binder.aval)
        for binder in jaxpr.in_binders
    ]

    # The traced function receives the input avals twice over
    # (presumably primals followed by tangents; confirm against jvp_traceable).
    doubled_avals = in_avals * 2
    traced_jaxpr, consts, _ = trace_to_jaxpr_dynamic(
        jvp_traceable(ad.jvp(wrapped)), doubled_avals)
    return traced_jaxpr, consts