def __call__(self, *args: Any, **kwargs: Any) -> ReturnValue:  # pytype: disable=invalid-annotation
  if not self.fwd or not self.bwd:
    msg = "No VJP defined for custom_vjp function {} using defvjp."
    raise AttributeError(msg.format(self.__name__))
  args = _resolve_kwargs(self.fun, args, kwargs)
  if config.jax_enable_custom_vjp_by_custom_transpose:
    if self.nondiff_argnums:
      raise NotImplementedError(
          'nondiff_argnums not implemented for new custom_vjp')
    return custom_vjp_by_custom_transpose(self.fun, self.fwd, self.bwd)(*args)
  else:
    if self.nondiff_argnums:
      for i in self.nondiff_argnums: _check_for_tracers(args[i])
      nondiff_argnums = set(self.nondiff_argnums)
      dyn_argnums = [i for i in range(len(args)) if i not in nondiff_argnums]
      f_, dyn_args = argnums_partial(lu.wrap_init(self.fun), dyn_argnums, args,
                                     require_static_args_hashable=False)
      static_args = [args[i] for i in self.nondiff_argnums]
      fwd, _ = argnums_partial(lu.wrap_init(self.fwd), dyn_argnums, args,
                               require_static_args_hashable=False)
      bwd = _add_args(lu.wrap_init(self.bwd), static_args)
    else:
      f_, dyn_args = lu.wrap_init(self.fun), args
      fwd, bwd = lu.wrap_init(self.fwd), lu.wrap_init(self.bwd)
    args_flat, in_tree = tree_flatten(dyn_args)
    in_avals = [core.raise_to_shaped(core.get_aval(x)) for x in args_flat]
    flat_fun, out_tree = flatten_fun_nokwargs(f_, in_tree)
    flat_fwd, out_trees = _flatten_fwd(fwd, in_tree)
    flat_bwd = _flatten_bwd(bwd, in_tree, in_avals, out_trees)
    out_flat = custom_vjp_call_p.bind(flat_fun, flat_fwd, flat_bwd,
                                      *args_flat, out_trees=out_trees)
    fst, aux = lu.merge_linear_aux(out_tree, out_trees)
    out_tree = aux if fst else aux[0]
    return tree_unflatten(out_tree, out_flat)
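# Usage sketch (not part of the source above): the public `jax.custom_vjp` API
# that this `__call__` implements. `clip_gradient` is a hypothetical example;
# `defvjp` installs the fwd/bwd rules whose presence is checked at the top of
# `__call__`.
import jax
import jax.numpy as jnp

@jax.custom_vjp
def clip_gradient(x):
  return x  # identity on the forward pass

def clip_gradient_fwd(x):
  return x, None  # no residuals needed for the backward pass

def clip_gradient_bwd(_, g):
  return (jnp.clip(g, -1.0, 1.0),)  # clip the incoming cotangent

clip_gradient.defvjp(clip_gradient_fwd, clip_gradient_bwd)

print(jax.grad(clip_gradient)(10.0))  # 1.0: the cotangent passes through unclipped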
def remat_transpose(params, call_jaxpr, primals_in, cotangents_in,
                    cotangent_in_avals, reduce_axes):
  # backward_pass can only transpose linear computations, but the call_jaxpr embedded in
  # remat contains primal (non-linear) equations too. Hence, we have to eliminate those
  # (in this case via partial_eval) before we call into backward_pass again.
  typed_call_jaxpr = core.ClosedJaxpr(call_jaxpr, [])
  unknowns = map(is_undefined_primal, primals_in)
  primal_jaxpr, tangent_jaxpr, out_unknowns = \
    pe.partial_eval_jaxpr(typed_call_jaxpr, unknowns=unknowns, instantiate=True)  # type: ignore

  def do_transpose(primals_in, cotangents_in):
    # NOTE: This is passing in undefined primals in place of tangent arguments, but it
    # should all work out, because we're only computing the primal part here.
    residuals = core.jaxpr_as_fun(primal_jaxpr)(*primals_in)[len(cotangents_in):]
    # Now that we have a purely linear jaxpr, we can transpose it
    cotangents_out = backward_pass(tangent_jaxpr.jaxpr, reduce_axes, (),
                                   primals_in + residuals, cotangents_in)
    # backward_pass will return cotangents computed for all invars, but some of them
    # are residuals appended by partial eval, so we need to skip those before we return.
    return cotangents_out[:len(primals_in)]

  flat_args, in_tree_def = tree_flatten((primals_in, cotangents_in))
  flat_do_transpose, out_tree = flatten_fun_nokwargs(lu.wrap_init(do_transpose),
                                                     in_tree_def)
  flat_cotangents_out = pe.remat_call_p.bind(flat_do_transpose, *flat_args, **params)
  return tree_unflatten(out_tree(), flat_cotangents_out)
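# Usage sketch (assumption: illustrative only). Differentiating through
# `jax.checkpoint` (a.k.a. `jax.remat`) is what exercises `remat_transpose`
# above: the rematerialized call_jaxpr mixes primal and linear equations, so it
# is partially evaluated before `backward_pass` transposes the linear part.
import jax
import jax.numpy as jnp

@jax.checkpoint
def g(x):
  return jnp.sin(jnp.sin(x))  # intermediates are recomputed on the backward pass

print(jax.grad(g)(2.0))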
def jvp_of_rule_rule(axis_size, in_batched, primals, tangents):
  in_batched_ps, in_batched_ts = in_batched

  mutually_batched = tree_map(operator.and_, in_batched_ps, in_batched_ts)
  extra_batched_ps = tree_map(lambda pb, tb: 0 if pb and not tb else None,
                              in_batched_ps, in_batched_ts)
  extra_batched_ts = tree_map(lambda pb, tb: 0 if tb and not pb else None,
                              in_batched_ps, in_batched_ts)

  out_mutually_batched = lu.Store()
  flat_ps_ts, tree_ps_ts = tree_flatten((primals, tangents))
  flat_extra_batched_ps_ts, tree_ps_ts2 = tree_flatten(
      (extra_batched_ps, extra_batched_ts),
      is_leaf=lambda x: x is None)

  # TODO(frostig): assert these also equal:
  #   treedef_tuple((in_tree, in_tree))
  # once https://github.com/google/jax/issues/9066 is fixed
  assert tree_ps_ts == tree_ps_ts2
  del tree_ps_ts2

  def to_jvp(*primals):
    out, out_batched = call_rule(rule, axis_size, mutually_batched, primals)
    check_vmap_rule_trees(
        rule, out_tree, tree_structure(out), tree_structure(out_batched))
    out_mutually_batched.store(out_batched)
    return out

  def to_vmap_over_extra_batched_dims(primals, tangents):
    return jax.jvp(to_jvp, primals, tangents)

  to_vmap_over_extra_batched_dims_flat, out_tree2 = flatten_fun_nokwargs(
      lu.wrap_init(to_vmap_over_extra_batched_dims), tree_ps_ts)

  flat_out_ps_ts, flat_out_axes = vmap_unrestricted(
      to_vmap_over_extra_batched_dims_flat, *flat_ps_ts,
      in_axes=flat_extra_batched_ps_ts,
      axis_name=core.no_axis_name, axis_size=axis_size)

  n, ragged = divmod(len(flat_out_ps_ts), 2)
  assert not ragged
  flat_out_ps, flat_out_ts = flat_out_ps_ts[:n], flat_out_ps_ts[n:]
  flat_out_axes_p, flat_out_axes_t = flat_out_axes[:n], flat_out_axes[n:]
  flat_out_ps = map(maybe_bdim_at_front, flat_out_ps, flat_out_axes_p)
  flat_out_extra_batched_ps = [d is not not_mapped for d in flat_out_axes_p]
  flat_out_ts = map(maybe_bdim_at_front, flat_out_ts, flat_out_axes_t)
  flat_out_extra_batched_ts = [d is not not_mapped for d in flat_out_axes_t]

  out_ps, out_ts = tree_unflatten(
      out_tree2(), [*flat_out_ps, *flat_out_ts])
  out_extra_batched_ps, out_extra_batched_ts = tree_unflatten(
      out_tree2(), [*flat_out_extra_batched_ps, *flat_out_extra_batched_ts])

  out_batched_ps = tree_map(
      operator.or_, out_mutually_batched.val, out_extra_batched_ps)
  out_batched_ts = tree_map(
      operator.or_, out_mutually_batched.val, out_extra_batched_ts)

  return (out_ps, out_ts), (out_batched_ps, out_batched_ts)
def __call__(self, *args: Any, **kwargs: Any) -> ReturnValue:  # pytype: disable=invalid-annotation
  if not self.jvp:
    msg = "No JVP defined for custom_jvp function {} using defjvp."
    raise AttributeError(msg.format(self.__name__))
  args = _resolve_kwargs(self.fun, args, kwargs)
  if self.nondiff_argnums:
    nondiff_argnums = set(self.nondiff_argnums)
    args = tuple(_stop_gradient(x) if i in nondiff_argnums else x
                 for i, x in enumerate(args))
    diff_argnums = [i for i in range(len(args)) if i not in nondiff_argnums]
    f_, dyn_args = argnums_partial(lu.wrap_init(self.fun), diff_argnums, args,
                                   require_static_args_hashable=False)
    static_args = [args[i] for i in self.nondiff_argnums]
    jvp = _add_args(lu.wrap_init(self.jvp), static_args)
  else:
    f_, dyn_args = lu.wrap_init(self.fun), args
    jvp = lu.wrap_init(self.jvp)
  args_flat, in_tree = tree_flatten(dyn_args)
  flat_fun, out_tree1 = flatten_fun_nokwargs(f_, in_tree)
  flat_jvp, out_tree2 = _flatten_jvp(jvp, in_tree)
  out_flat = custom_jvp_call_p.bind(flat_fun, flat_jvp, *args_flat)
  _, out_tree = lu.merge_linear_aux(out_tree1, out_tree2)
  return tree_unflatten(out_tree, out_flat)
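# Usage sketch (not part of the source above): the public `jax.custom_jvp` API
# that this `__call__` implements, using the log1pexp example from the JAX docs.
import jax
import jax.numpy as jnp

@jax.custom_jvp
def log1pexp(x):
  return jnp.log(1. + jnp.exp(x))

@log1pexp.defjvp
def log1pexp_jvp(primals, tangents):
  x, = primals
  x_dot, = tangents
  ans = log1pexp(x)
  ans_dot = (1. - 1. / (1. + jnp.exp(x))) * x_dot  # numerically stable derivative
  return ans, ans_dot

print(jax.grad(log1pexp)(100.))  # 1.0, with no nan from the naive rule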
def fwd(*args, **kwargs):
  ans, rule = fun(*args, **kwargs)
  ans_flat, out_tree = tree_flatten((ans,))
  rule, in_tree = flatten_fun_nokwargs(lu.wrap_init(rule), out_tree)
  ans_avals = [core.get_aval(x).at_least_vspace() for x in ans_flat]
  jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(rule, ans_avals)
  return ans, Residuals(jaxpr, in_tree(), out_tree, consts)
def wrapped(spenv: SparsifyEnv, *spvalues: SparsifyValue,
            **params: Any) -> Tuple[Sequence[SparsifyValue], bool]:
  spvalues_flat, in_tree = tree_flatten(spvalues, is_leaf=_is_spvalue)
  in_avals_flat = spvalues_to_avals(spenv, spvalues_flat)
  wrapped_fun, out_tree = flatten_fun_nokwargs(lu.wrap_init(f, params), in_tree)
  jaxpr, out_avals_flat, consts = pe.trace_to_jaxpr_dynamic(wrapped_fun, in_avals_flat)
  result = eval_sparse(jaxpr, consts, spvalues_flat, spenv)
  if len(out_avals_flat) != len(result):
    raise Exception("Internal: eval_sparse does not return expected number of arguments. "
                    f"Got {result} for avals {out_avals_flat}")
  return result, out_tree()
def remat_transpose(params, call_jaxpr, primals_in, cotangents_in,
                    cotangent_in_avals, reduce_axes):
  call_jaxpr = _close_jaxpr(call_jaxpr)
  unknowns = map(is_undefined_primal, primals_in)
  primal_jaxpr, tangent_jaxpr, _ = \
    pe.partial_eval_jaxpr(call_jaxpr, unknowns=unknowns, instantiate=True)  # type: ignore
  args, in_tree_def = tree_flatten((primals_in, cotangents_in))
  transpose = lu.hashable_partial(lu.wrap_init(_remat_transpose), primal_jaxpr,
                                  tangent_jaxpr, reduce_axes)
  flat_transpose, out_tree = flatten_fun_nokwargs(transpose, in_tree_def)
  flat_cotangents_out = pe.remat_call_p.bind(flat_transpose, *args, **params)
  return tree_unflatten(out_tree(), flat_cotangents_out)
def __call__(self, *args, **kwargs):
  if self.ivjp is None:
    msg = "No IVJP defined for custom_vjp function {}. Did you forget to use defivjp?"
    raise AttributeError(msg.format(self.__name__))
  args = custom_derivatives._resolve_kwargs(self.fun, args, kwargs)
  # TODO: Support nondiff_argnums
  fun, ivjp = lu.wrap_init(self.fun), lu.wrap_init(self.ivjp)
  args_flat, in_tree = tree_flatten(args)
  flat_fun, out_tree = flatten_fun_nokwargs(fun, in_tree)
  flat_ivjp = _flatten_ivjp(ivjp, in_tree, out_tree)
  out_flat = _custom_ivjp(flat_fun, flat_ivjp, args_flat)
  return tree_unflatten(out_tree(), out_flat)
def call_transpose(primitive, params, call_jaxpr, args, ct, _, reduce_axes):
  all_args, in_tree_def = tree_flatten(((), args, ct))  # empty consts
  fun = lu.hashable_partial(lu.wrap_init(backward_pass), call_jaxpr, reduce_axes)
  fun, out_tree = flatten_fun_nokwargs(fun, in_tree_def)
  new_params = dict(params, name=wrap_name(params['name'], 'transpose'))
  update_params = call_transpose_param_updaters.get(primitive)
  if update_params:
    new_params = update_params(new_params, map(is_undefined_primal, args),
                               [type(x) is not Zero for x in ct])
  out_flat = primitive.bind(fun, *all_args, **new_params)
  return tree_unflatten(out_tree(), out_flat)
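# Usage sketch (assumption: illustrative only). `call_transpose` is the
# transpose rule used when the linear computation being transposed contains a
# call primitive (e.g. a jit-traced subcomputation). At the user level,
# transposition is exposed through `jax.linear_transpose`:
import jax
import jax.numpy as jnp

f = jax.jit(lambda x: 2.0 * x)               # a linear function wrapped in a call
f_t = jax.linear_transpose(f, jnp.zeros(3))  # build the transposed function
print(f_t(jnp.ones(3)))                      # a 1-tuple containing [2., 2., 2.]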
def fwd(*args):
  flat_args, in_tree = tree_flatten(args)
  in_pvals = tuple(pe.PartialVal.unknown(raise_to_shaped(get_aval(arg)))
                   for arg in flat_args)
  fun_flat, out_tree = flatten_fun_nokwargs(lu.wrap_init(fun), in_tree)
  jaxpr, out_pvals, consts = pe.trace_to_jaxpr(fun_flat, in_pvals)
  # TODO: Don't warn if consts contain JVP tracers?
  if consts:
    warnings.warn("Values that an @invertible function closes over will not have their " +
                  "gradients computed correctly (their uses inside this function will be ignored)!")
  # TODO: This requires the body to be jittable, but this shouldn't be necessary.
  #       Is there a way to trace a jaxpr while running it?
  flat_outs = core.eval_jaxpr(jaxpr, consts, *flat_args)
  return tree_unflatten(out_tree(), flat_outs), (flat_args, flat_outs, consts,
                                                 DontFlatten((jaxpr, in_tree)))
def __call__(self, *args, **kwargs):
  assert not kwargs
  args_flat, in_tree = tree_flatten(args)
  flat_fun, out_tree = flatten_fun_nokwargs(lu.wrap_init(self.fun), in_tree)
  in_avals = [core.raise_to_shaped(core.get_aval(x)) for x in args_flat]
  debug = pe.debug_info(self.fun, in_tree, False, "custom_vmap")
  jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals, debug)
  assert not len(consts)
  closed_call = core.ClosedJaxpr(pe.convert_constvars_jaxpr(jaxpr), ())
  out_flat = custom_vmap_p.bind(*consts, *args_flat,
                                call=closed_call,
                                rule=self.vmap_rule,
                                in_tree=in_tree)
  return tree_unflatten(out_tree(), out_flat)
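# Usage sketch (assumption: `jax.custom_batching.custom_vmap` is the public
# entry point behind this `__call__`, and `def_vmap` is how `self.vmap_rule`
# gets registered; names are illustrative and not verified against any one
# release).
import jax
import jax.numpy as jnp
from jax.custom_batching import custom_vmap

@custom_vmap
def f(x):
  return jnp.sin(x)

@f.def_vmap
def f_vmap_rule(axis_size, in_batched, x):
  x_batched, = in_batched
  # A rule returns the batched outputs plus flags saying which outputs carry
  # a batch dimension, mirroring the (out, out_batched) pairs handled above.
  return jnp.sin(x), x_batched

print(jax.vmap(f)(jnp.arange(3.0)))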
def call_transpose(primitive, params, call_jaxpr, args, ct, _, reduce_axes):
  all_args, in_tree_def = tree_flatten(((), args, ct))  # empty consts
  fun = lu.hashable_partial(lu.wrap_init(backward_pass), call_jaxpr, reduce_axes, False)
  fun, out_tree = flatten_fun_nokwargs(fun, in_tree_def)
  if not config.jax_experimental_name_stack:
    params = dict(params, name=wrap_name(params['name'], 'transpose'))
  update_params = call_transpose_param_updaters.get(primitive)
  if update_params:
    params = update_params(params, map(is_undefined_primal, args),
                           [type(x) is not Zero for x in ct])
  if config.jax_dynamic_shapes:
    in_type = [(core.raise_to_shaped(core.get_aval(x)), True) for x in all_args]
    fun = lu.annotate(fun, tuple(in_type))
  out_flat = primitive.bind(fun, *all_args, **params)
  return tree_unflatten(out_tree(), out_flat)
def map_transpose(primitive, params, call_jaxpr, args, ct, _, reduce_axes):
  all_args, in_tree_def = tree_flatten(((), args, ct))  # empty consts
  fun = lu.hashable_partial(lu.wrap_init(backward_pass), call_jaxpr, reduce_axes, False)
  fun, nz_arg_cts = nonzero_outputs(fun)
  fun, out_tree = flatten_fun_nokwargs(fun, in_tree_def)
  # Preserve axis for primal arguments, skip tangents (represented as undefined primals).
  in_axes, out_axes = params['in_axes'], params['out_axes']
  new_in_axes = (*[axis for axis, x in zip(in_axes, args)
                   if not is_undefined_primal(x)],
                 *[axis for axis, x in zip(out_axes, ct)
                   if type(x) is not Zero])
  # The interim strategy we use below (until avals-with-names) only works
  # when all outputs are mapped.
  assert all(out_axis is not None for out_axis in out_axes), out_axes
  # NOTE: This assumes that the output cotangents being zero is a deterministic
  # function of which input cotangents were zero.
  @as_hashable_function(closure=(in_axes, tuple(type(c) is Zero for c in ct)))
  def out_axes_thunk():
    return tuple(axis or 0 for axis, nz in zip(in_axes, nz_arg_cts()) if nz)
  new_params = dict(params, name=wrap_name(params['name'], 'transpose'),
                    in_axes=new_in_axes, out_axes_thunk=out_axes_thunk)
  del new_params['out_axes']
  update_params = call_transpose_param_updaters.get(primitive)
  if update_params:
    new_params = update_params(new_params, map(is_undefined_primal, args),
                               [type(x) is not Zero for x in ct])
  out_flat = primitive.bind(fun, *all_args, **new_params)
  arg_cts = tree_unflatten(out_tree(), out_flat)

  # The freevars are being fanned out (not mapped). During transpose the
  # dual of fan-out is fan-in-sum. We apply it to the unmapped invars.
  assert len(in_axes) == len(arg_cts)
  def unmap_zero(zero, in_axis):
    return (zero if in_axis is None else
            Zero(core.unmapped_aval(params['axis_size'], params['axis_name'],
                                    in_axis, zero.aval)))
  arg_cts = (unmap_zero(arg_ct, in_axis) if type(arg_ct) is Zero else
             arg_ct if in_axis is not None else
             arg_ct.sum(0)
             for arg_ct, in_axis in zip(arg_cts, in_axes))
  return tuple(arg_cts)
def _closure_convert_for_avals(fun, in_tree, in_avals):
  wrapped_fun, out_tree = flatten_fun_nokwargs(lu.wrap_init(fun), in_tree)
  jaxpr, out_pvals, consts = pe.trace_to_jaxpr_dynamic(wrapped_fun, in_avals)
  out_tree = out_tree()

  (closure_consts, hoisted_consts), merge = partition_list(_maybe_perturbed, consts)
  num_consts = len(hoisted_consts)

  def converted_fun(*args_hconsts):
    num_args = len(args_hconsts) - num_consts
    args, hoisted_consts = split_list(args_hconsts, [num_args])
    consts = merge(closure_consts, hoisted_consts)
    all_args, in_tree2 = tree_flatten(tuple(args))
    assert in_tree == in_tree2
    out_flat = core.eval_jaxpr(jaxpr, consts, *all_args)
    return tree_unflatten(out_tree, out_flat)

  return converted_fun, hoisted_consts
def __call__(self, residual_arg, linear_arg):
  res_arg, lin_arg = residual_arg, linear_arg
  _, res_tree = tree_flatten(res_arg)
  _, lin_tree = tree_flatten(lin_arg)
  args_flat, in_tree = tree_flatten((res_arg, lin_arg))

  flat_fun, out_tree = flatten_fun_nokwargs(lu.wrap_init(self.fun), in_tree)
  in_avals = [core.raise_to_shaped(core.get_aval(x)) for x in args_flat]
  debug = pe.debug_info(self.fun, in_tree, False, "custom_transpose")
  jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals, debug)
  assert not len(consts)
  closed_call = core.ClosedJaxpr(pe.convert_constvars_jaxpr(jaxpr), ())
  out_flat = custom_transpose_p.bind(*consts, *args_flat,
                                     call=closed_call,
                                     rule=self.transpose,
                                     lin_tree=lin_tree,
                                     res_tree=res_tree,
                                     out_tree=out_tree())
  return tree_unflatten(out_tree(), out_flat)
def linear_call(fun: Callable, fun_transpose: Callable, residual_args,
                linear_args):
  """Call a linear function, with a custom implementation for its transpose.

  The type signatures of ``fun`` and ``fun_transpose`` are:

  .. code-block:: haskell

    fun           :: r -> a -o b
    fun_transpose :: r -> b -o a

  where the ``-o`` arrow indicates a linear function, ``r`` is the residual
  input type, and ``a`` is the linear input type.

  The functions ``fun`` and ``fun_transpose`` are coupled as transposes of one
  another. Specifically, the transpose of a ``linear_call`` primitive is
  another ``linear_call`` to ``fun_transpose``, with ``fun`` as its custom
  transposition.

  For example:

  >>> def f(r, x):
  ...   return x / r

  >>> def t(r, t):
  ...   return t / r

  >>> def div_add(x, denom):
  ...   return x + linear_call(f, t, denom, x)

  >>> def transpose(f, x_example):
  ...   def transposed(y):
  ...     x, = jax.linear_transpose(f, x_example)(y)
  ...     return x
  ...   return transposed

  >>> div_add(9., 3.)
  DeviceArray(12., dtype=float32, weak_type=True)

  >>> transpose(partial(div_add, denom=3.), 1.)(18.)  # custom
  DeviceArray(24., dtype=float32, weak_type=True)

  >>> transpose(lambda x: x + x / 3., 1.)(18.)  # reference
  DeviceArray(24., dtype=float32, weak_type=True)

  The above definition of ``f`` illustrates the purpose of a residual argument:
  division is linear in one of its inputs (the dividend ``x``) but not the
  other (the divisor ``r``).

  As another example:

  >>> def custom_id(x):
  ...   def f(_, x): return x
  ...   def t(_, t): return 7.
  ...   return linear_call(f, t, (), x)
  >>> custom_id(1.)
  1.0
  >>> transpose(custom_id, 1.)(1.)
  7.0
  >>> transpose(transpose(custom_id, 1.), 1.)(1.)
  1.0
  >>> transpose(transpose(transpose(custom_id, 1.), 1.), 1.)(1.)
  7.0

  Args:
    fun: a Python callable specifying a linear function. It should take two
      arguments: one of "residual" inputs (type ``r``), i.e. inputs in which
      the function is not necessarily linear, and one of "linear" inputs
      (type ``a``). It should return output whose components are linear in
      the linear input (type ``b``).
    fun_transpose: a Python callable specifying a structurally linear function
      that is the transpose of ``fun`` with respect to its linear inputs. Its
      first argument is the same residual inputs (``r``) as ``fun``. Its
      second argument is of type ``b``. Finally, its output is of type ``a``
      and each of its components is linear in its second argument (the ``b``
      inputs).
    residual_args: Argument in which ``fun`` and ``fun_transpose`` are not
      necessarily linear. Not involved in transposition.
    linear_args: Argument in which ``fun`` and ``fun_transpose`` are linear
      and with respect to which the two are transposes.

  Returns:
    The call result, i.e. ``fun(residual_args, linear_args)``.
  """
  operands_res, res_tree = tree_flatten(residual_args)
  operands_lin, lin_tree = tree_flatten(linear_args)

  f_in_tree = treedef_tuple((res_tree, lin_tree))
  f, out_tree = flatten_fun_nokwargs(lu.wrap_init(fun), f_in_tree)

  res_avals = map(abstractify, operands_res)
  lin_avals = map(abstractify, operands_lin)
  f_jaxpr, f_consts = _initial_style_jaxpr(f, (*res_avals, *lin_avals))
  f_jaxpr = _close_jaxpr(f_jaxpr)
  out_avals = map(core.raise_to_shaped, f_jaxpr.out_avals)

  t_in_tree = treedef_tuple((res_tree, out_tree()))
  t, t_out_tree = flatten_fun_nokwargs(lu.wrap_init(fun_transpose), t_in_tree)

  t_jaxpr, t_consts = _initial_style_jaxpr(t, (*res_avals, *out_avals))
  t_jaxpr = _close_jaxpr(t_jaxpr)

  if t_out_tree() != lin_tree:
    raise TypeError(
        'transpose output pytree structure must match that of linear inputs, '
        f'got output structure {t_out_tree()} '
        f'and input structure {lin_tree}.')

  out = linear_call_p.bind(*f_consts, *t_consts, *operands_res, *operands_lin,
                           callee=f_jaxpr,
                           transpose=t_jaxpr,
                           num_callee_consts=len(f_consts),
                           num_transpose_consts=len(t_consts),
                           num_res=len(operands_res))

  return tree_unflatten(out_tree(), out)
def inv_backward_pass(jaxpr: core.Jaxpr, consts, primals_in, primals_out, cotangents_in):
  if all(type(ct) is ad.Zero for ct in cotangents_in):
    return map(lambda v: ad.Zero(v.aval), jaxpr.invars)

  def write_cotangent(v, ct):
    # assert v not in primal_env
    if ct is not None and type(v) is not Literal:
      ct_env[v] = ad.add_tangents(ct_env[v], ct) if v in ct_env else ct

  def read_cotangent(v):
    return ct_env.get(v, ad.Zero(v.aval))

  def read_primal(v):
    if type(v) is Literal:
      return v.val
    else:
      return primal_env.get(v, ad.UndefinedPrimal(v.aval))

  def write_primal(v, val):
    if type(v) is Literal:
      return
    primal_env.setdefault(v, val)

  # Invert while computing cotangents
  ct_env: Dict[Any, Any] = {}
  primal_env: Dict[Any, Any] = {}
  write_primal(core.unitvar, core.unit)
  map(write_primal, jaxpr.invars, primals_in)
  map(write_primal, jaxpr.outvars, primals_out)
  map(write_primal, jaxpr.constvars, consts)
  map(write_cotangent, jaxpr.outvars, cotangents_in)
  for eqn in jaxpr.eqns[::-1]:
    primals_in = map(read_primal, eqn.invars)
    primals_out = map(read_primal, eqn.outvars)
    cts_in = map(read_cotangent, eqn.outvars)
    should_invert = any(type(primal) is not ad.UndefinedPrimal
                        for primal in primals_out)
    should_vjp = any(type(ct) is not ad.Zero for ct in cts_in)
    assert not eqn.primitive.call_primitive

    # Skip primals equations that are only jvp coefficients and don't affect
    # primal outputs.
    if not should_invert and not should_vjp:
      continue

    def abstract(value):
      return raise_to_shaped(value.aval if ad.is_undefined_primal(value)
                             else get_aval(value))

    # Get the ivjp_jaxpr
    if eqn.primitive is custom_ivjp_p:
      ivjp_jaxpr = eqn.params['ivjp_jaxpr']
    else:
      if eqn.primitive in primitive_ivjps:
        complete_ivjp = lu.wrap_init(primitive_ivjps[eqn.primitive])
      else:
        complete_ivjp = lu.wrap_init(partial(synthesize_ivjp, eqn,
                                             map(ad.is_undefined_primal, primals_in)))
      _, in_tree = tree_flatten(
          tuple(map(abstract, x) for x in (primals_in, primals_out, primals_out)))
      complete_ivjp_flat, _ = flatten_fun_nokwargs(complete_ivjp, in_tree)

      in_avals = map(abstract, primals_in + primals_out + primals_out)
      # TODO: Actually we do know some of the inputs, because they might be literals!
      ivjp_jaxpr, out_pvals, _ = pe.trace_to_jaxpr(
          complete_ivjp_flat, map(pe.PartialVal.unknown, in_avals), instantiate=True)
      assert not ivjp_jaxpr.constvars  # That might happen some time, but don't bother until then
      ivjp_jaxpr = core.ClosedJaxpr(ivjp_jaxpr, [])

    # Once we know what the ivjp can do exactly, we have to isolate the part we are
    # actually able to compute with the values we have at hand.
    num_inputs = len(eqn.invars)
    unknowns = (map(ad.is_undefined_primal, primals_in) +
                map(ad.is_undefined_primal, primals_out) +
                [False] * len(cts_in))
    jaxpr_known, jaxpr_unknown, out_unknowns = pe.partial_eval_jaxpr(  # type: ignore
        ivjp_jaxpr, unknowns, instantiate=False)  # type:ignore
    unknown_rec_primals_in, unknown_cotangents = split_list(out_unknowns, [num_inputs])
    # Make sure we're able to compute all cotangents. We don't really care if we
    # can reconstruct the primals or not, although failure to do so might result in
    # failing to compute cotangents later.
    assert not any(unknown_cotangents)
    # Remove residual outputs -- we won't be computing the unknown jaxpr anyway.
    num_outputs = len(jaxpr_unknown.jaxpr.outvars)
    jaxpr_known.jaxpr.outvars = jaxpr_known.jaxpr.outvars[:num_outputs]
    # TODO: We could drop the outputs that correspond to primals that we already know.
    #       This only matters in eager mode, so leaving it out for now...
    ivjp = core.jaxpr_as_fun(jaxpr_known)
    rec_primals_in, cts_out = split_list(ivjp(*primals_in, *primals_out, *cts_in),
                                         [num_inputs])
    # Unknown rec_primals_in are core.units, so we have to replace them
    # with UnknownPrimals because that's what write_primal will ignore.
    rec_primals_in = [prev if unknown else rec
                      for prev, rec, unknown
                      in zip(primals_in, rec_primals_in, unknown_rec_primals_in)]
    map(write_primal, eqn.invars, rec_primals_in)
    map(write_cotangent, [v for v in eqn.invars if type(v) is not Literal], cts_out)

  # NOTE: We keep the cotangents associated with primal variables, while the contract of a
  #       transpose is to return them in positions associated with tangent variables, which
  #       is what causes this whole confusion.
  return map(read_cotangent, jaxpr.invars)
def wrapped_fun(*args):
  args_flat, in_tree = tree_flatten(args)
  f = lu.wrap_init(fun)
  flat_fun, out_tree = flatten_fun_nokwargs(f, in_tree)
  out_flat = callback_fun(flat_fun, args_flat, callback, strip_calls)
  return tree_unflatten(out_tree(), out_flat)
def _wrapped(*args):
  args_flat, in_tree = tree_flatten(args, is_leaf=_is_bcoo)
  wrapped_fun, out_tree = flatten_fun_nokwargs(lu.wrap_init(fun), in_tree)
  out = sparsify_fun(wrapped_fun, args_flat)
  return tree_unflatten(out_tree(), out)
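# Usage sketch (assumption: illustrative only). `_wrapped` is the closure
# returned by `jax.experimental.sparse.sparsify`: it flattens the (possibly
# BCOO) arguments, traces `fun`, and evaluates the jaxpr with sparse rules.
import jax.numpy as jnp
from jax.experimental import sparse

M = sparse.BCOO.fromdense(jnp.array([[0., 1.], [2., 0.]]))

@sparse.sparsify
def matvec(mat, vec):
  return mat @ vec  # dot_general has a sparse rule, so BCOO inputs are handled

print(matvec(M, jnp.ones(2)))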