Example #1
0
def _root_jvp(const_lengths, jaxprs, primals, tangents):
    """Forward-mode (JVP) rule for the custom root computation.

    The tangent of the solution is obtained by implicit differentiation
    rather than by differentiating through the solver:

        F(m, u) = 0      system of equations in u, parameterized by m;
                         its solution u*(m) is defined in a neighborhood
        F(m, u*(m)) = 0  satisfied in that neighborhood

    Differentiating the identity above:

        ∂_0 F(m, u*(m)) + ∂_1 F(m, u*(m)) ∂ u*(m) = 0
        ∂ u*(m)    = - (∂_1 F(m, u*(m)))^{-1} ∂_0 F(m, u*(m))
        ∂ u*(m)[v] = - (∂_1 F(m, u*(m)))^{-1} [∂_0 F(m, u*(m))[v]]   (jvp)

    Args:
      const_lengths: forwarded to ``_split_root_args`` and ``_custom_root``
        to separate the constant arguments of each jaxpr.
      jaxprs: holds the residual jaxpr ``f`` and the ``l_and_s`` jaxpr, which
        is used here as the linearize-and-solve step.
      primals: primal inputs.
      tangents: tangents of ``primals``.

    Returns:
      ``(outputs, output_tangents)``. The outputs are the solution values
      followed by the aux values. The aux tangents are concrete zero arrays
      produced by ``lax.zeros_like_array``, not symbolic zeros.
    """
    consts, _ = _split_root_args(primals, const_lengths)
    outputs = _custom_root(const_lengths, jaxprs, *primals)

    # The first len(f.out_avals) outputs are the root; the remainder are aux.
    num_root_outs = len(jaxprs.f.out_avals)
    root, aux = split_list(outputs, [num_root_outs])

    consts_dot, _ = _split_root_args(tangents, const_lengths)

    residual_fn = core.jaxpr_as_fun(jaxprs.f)
    solve_fn = core.jaxpr_as_fun(jaxprs.l_and_s)

    def residual_at_root(*f_consts):
        # F(m, u*) as a function of the parameters m alone.
        return residual_fn(*f_consts, *root)

    # rhs = ∂_0 F(m, u*(m))[v]
    _, rhs = ad.jvp(lu.wrap_init(residual_at_root)).call_wrapped(
        consts.f, consts_dot.f)
    # Solve against rhs, then negate to obtain ∂ u*(m)[v].
    root_dot = [-t for t in solve_fn(*consts.l_and_s, *root, *rhs)]

    # Append the aux outputs, with zero-array tangents for each of them.
    root += aux
    root_dot += _map(lax.zeros_like_array, aux)
    return root, root_dot
Example #2
0
def _custom_ivjp_jvp(primals, tangents, *, fun_jaxpr, ivjp_jaxpr):
  """JVP rule for ``custom_ivjp_p``.

  The primal outputs come from re-binding ``custom_ivjp_p`` with the same
  jaxprs. The tangents come from forward-differentiating ``fun_jaxpr``
  directly, so ``ivjp_jaxpr`` plays no part in the tangent computation.
  """
  out = custom_ivjp_p.bind(*primals,
                           fun_jaxpr=fun_jaxpr,
                           ivjp_jaxpr=ivjp_jaxpr)
  # FIXME: This might compute the primals multiple times, but we only need to do
  #        this trick while linearizing. It should be possible to do it through
  #        a custom partial eval rule.
  wrapped_fun = lu.wrap_init(core.jaxpr_as_fun(fun_jaxpr))
  _, out_tangents = ad.jvp(wrapped_fun).call_wrapped(primals, tangents)
  return out, out_tangents
Example #3
0
def _tangent_linear_map(func, params, params_dot, *x):
  """Compute the tangent of a linear map.

  Assuming ``func(*params, *x)`` is linear in ``x`` and computes ``A @ x``,
  this function computes ``∂A @ x``. It differentiates ``func`` only with
  respect to ``params`` and holds ``x`` fixed by giving every ``x`` entry a
  symbolic-zero tangent.

  Args:
    func: callable invoked as ``func(*params, *x)``.
    params: sequence (list or tuple) of the values that define ``A``.
    params_dot: sequence of tangents, paired element-wise with ``params``.
      At least one entry must not be ``ad_util.zero``.
    *x: the operand(s) the linear map is applied to.

  Returns:
    The output tangent(s) of ``func``, as returned by the JVP.
  """
  # Callers should only request this when some parameter actually varies.
  assert any(p is not ad_util.zero for p in params_dot)
  zeros = [ad_util.zero] * len(x)
  # Unpack into fresh lists so both list and tuple sequences are accepted;
  # ``params + list(x)`` raised TypeError when ``params`` was a tuple.
  _, out_tangent = ad.jvp(lu.wrap_init(func)).call_wrapped(
      [*params, *x], [*params_dot, *zeros])
  return out_tangent
Example #4
0
def _root_jvp(primals, tangents, num_consts, jaxpr, solve, tangent_solve):
  """JVP rule for ``root_p`` via implicit differentiation.

  Let F(m, u) = 0 be a system of equations in u, parameterized by m, whose
  solution u*(m) is defined in a neighborhood, so F(m, u*(m)) = 0 there.
  Differentiating that identity:

    ∂_0 F(m, u*(m)) + ∂_1 F(m, u*(m)) ∂ u*(m) = 0        # implied by above
    ∂ u*(m) = - (∂_1 F(m, u*(m)))^{-1} ∂_0 F(m, u*(m))   # rearrange
    ∂ u*(m)[v] = - (∂_1 F(m, u*(m)))^{-1} [∂_0 F(m, u*(m))[v]]  # jvp

  Args:
    primals: primal inputs. The first ``num_consts`` are the parameters m, and
      all of them are re-bound to ``root_p`` to compute the solution.
    tangents: tangents of ``primals``; only the first ``num_consts`` are used.
    num_consts, jaxpr, solve, tangent_solve: ``root_p`` parameters. ``jaxpr``
      evaluates F on ``(*params, *solution)``. ``tangent_solve`` is called as
      ``tangent_solve(linear_fn, *rhs)``, and its results are negated to give
      the solution tangents.

  Returns:
    ``(solution, solution_dot)``: the root as a tuple and its tangents as a
    list.
  """
  params = primals[:num_consts]
  solution = tuple(root_p.bind(*primals, num_consts=num_consts, jaxpr=jaxpr,
                               solve=solve, tangent_solve=tangent_solve))
  params_dot = tangents[:num_consts]

  f = core.jaxpr_as_fun(jaxpr)

  # Unpacking (rather than ``params + solution``) works whether ``primals``
  # arrives as a tuple or a list; list + tuple would raise TypeError.
  def f_fixed_params(*sol):
    # F(m, ·): function of the solution with the parameters held fixed.
    return f(*params, *sol)

  def f_fixed_solution(*ps):
    # F(·, u*): function of the parameters with the solution held fixed.
    return f(*ps, *solution)

  # rhs = ∂_0 F(m, u*(m))[v]
  _, rhs = ad.jvp(lu.wrap_init(f_fixed_solution)).call_wrapped(
      params, params_dot)
  # Linearize F in u at u*, giving the linear map handed to tangent_solve.
  _, f_jvp_wrt_solution = api.linearize(f_fixed_params, *solution)
  solution_dot = [-x for x in tangent_solve(f_jvp_wrt_solution, *rhs)]

  return solution, solution_dot
Example #5
0
 def _jvp(primals, tangents, **params):
   """JVP rule that forward-differentiates ``self.impl`` directly.

   ``params`` are handed to ``lu.wrap_init`` together with ``self.impl``
   (presumably supplied to it as keyword arguments — confirm in lu).
   """
   impl_with_params = lu.wrap_init(self.impl, params)
   return ad.jvp(impl_with_params).call_wrapped(primals, tangents)
Example #6
0
def jvp_jaxpr(jaxpr):
  """Trace the JVP of ``jaxpr`` into a new jaxpr.

  The input avals are listed twice (``in_avals * 2``), presumably one copy
  for the primals and one for the tangents.

  Returns:
    ``(jvp_jaxpr, consts)`` from ``trace_to_jaxpr_dynamic``; its third result
    is discarded.
  """
  f = lu.wrap_init(jaxpr_as_fun(jaxpr))
  # Map each dimension binder to its aval; presumably used by
  # _replace_vars_with_avals to substitute dim vars inside input avals.
  dimvars = {v: v.aval for v in jaxpr.in_dim_binders}
  in_avals = [_replace_vars_with_avals(dimvars, v.aval)
              for v in jaxpr.in_binders]
  # Bind to a new name instead of shadowing the ``jaxpr`` parameter.
  out_jaxpr, consts, _ = trace_to_jaxpr_dynamic(
      jvp_traceable(ad.jvp(f)), in_avals * 2)
  return out_jaxpr, consts