def _while_callback_rule(trace, *tracers, cond_jaxpr, body_jaxpr, cond_nconsts, body_nconsts):
  """Callback rule for `while_p`: transform cond and body, then rebind the loop."""
  c_const_tracers, b_const_tracers, carry_tracers = split_list(
      tracers, [cond_nconsts, body_nconsts])
  carry_avals = safe_map(lambda t: t.aval, carry_tracers)
  c_const_vals, b_const_vals, carry_vals = tree_map(
      lambda t: t.val, (c_const_tracers, b_const_tracers, carry_tracers))
  run_cond = jaxpr_as_fun(cond_jaxpr)
  run_body = jaxpr_as_fun(body_jaxpr)

  def cond(*carry):
    # Close over the cond constants so the retraced cond takes only the carry.
    return run_cond(*c_const_vals, *carry)

  def body(*carry):
    return run_body(*b_const_vals, *carry)

  cond_cb = callback_transform(cond, trace.callback, strip_calls=trace.strip_calls)  # type: ignore
  body_cb = callback_transform(body, trace.callback, strip_calls=trace.strip_calls)  # type: ignore
  carry_tree = tree_structure(carry_avals)
  cb_cond_jaxpr, cb_cond_consts, _ = lcf._initial_style_jaxpr(
      cond_cb, carry_tree, tuple(carry_avals))
  cb_body_jaxpr, cb_body_consts, _ = lcf._initial_style_jaxpr(
      body_cb, carry_tree, tuple(carry_avals))
  outs = lcf.while_p.bind(
      *cb_cond_consts, *cb_body_consts, *carry_vals,
      cond_nconsts=len(cb_cond_consts), body_nconsts=len(cb_body_consts),
      cond_jaxpr=cb_cond_jaxpr, body_jaxpr=cb_body_jaxpr)
  return safe_map(trace.pure, outs)
def _root_jvp(const_lengths, jaxprs, primals, tangents):
  """JVP rule for custom_root, via the implicit function theorem.

  Solves the root problem at the primal values, then obtains the solution
  tangent by linearizing F at the solution and applying the user-supplied
  linearize-and-solve jaxpr to the negated directional derivative.
  """
  params, _ = _split_root_args(primals, const_lengths)
  sol = _custom_root(const_lengths, jaxprs, *primals)

  # The solver's outputs are the solution leaves followed by auxiliary values.
  f_out_vals = len(jaxprs.f.out_avals)
  solution, aux = split_list(sol, [f_out_vals])

  params_dot, _ = _split_root_args(tangents, const_lengths)

  # F(m, u) = 0      # system of equations in u, parameterized by m
  #                  # solution is u*(m) defined in a neighborhood
  # F(m, u*(m)) = 0  # satisfied in a neighborhood
  #
  # ∂_0 F(m, u*(m)) + ∂_1 F(m, u*(m)) ∂ u*(m) = 0       # implied by line above
  # ∂ u*(m) = - (∂_1 F(m, u*(m)))^{-1} ∂_0 F(m, u*(m))  # rearrange
  #
  # ∂ u*(m)[v] = - (∂_1 F(m, u*(m)))^{-1} [∂_0 F(m, u*(m))[v]]    # jvp

  f = core.jaxpr_as_fun(jaxprs.f)
  linearize_and_solve = partial(
      core.jaxpr_as_fun(jaxprs.l_and_s), *params.l_and_s)
  # ∂_0 F at the fixed solution: differentiate only w.r.t. the parameters m.
  f_at_solution = lambda *params: f(*params, *solution)
  _, rhs = ad.jvp(lu.wrap_init(f_at_solution)).call_wrapped(
      params.f, params_dot.f)
  solution_dot = _map(operator.neg, linearize_and_solve(*solution, *rhs))
  # append aux, create symbolic zero tangents for the aux values
  solution += aux
  solution_dot += _map(lax.zeros_like_array, aux)

  return solution, solution_dot
def _cond_batching_rule(args, dims, true_jaxpr, false_jaxpr, true_nconsts, false_nconsts):
  """Batching (vmap) rule for the two-branch cond primitive.

  If the predicate itself is batched, the cond is promoted to an elementwise
  select between both (batched) branch outputs; otherwise cond is rebound with
  batched branch jaxprs.
  """
  # TODO: maybe avoid moving arg axes to front if we're promoting to select?
  args = [batching.moveaxis(x, d, 0) if d is not batching.not_mapped and d != 0
          else x for x, d in zip(args, dims)]
  # Branch jaxpr inputs are consts followed by operands, so the operand count
  # is the total input count minus the constant count.
  true_nops = len(true_jaxpr.in_avals) - true_nconsts
  (pred,), true_consts, true_ops, false_consts, false_ops = split_list(
      args, [1, true_nconsts, true_nops, false_nconsts])
  # All batched args must agree on the mapped-axis size.
  size, = {x.shape[d] for x, d in zip(args, dims) if d is not batching.not_mapped}
  orig_bat = [d is not batching.not_mapped for d in dims]
  (pred_bat,), tconst_bat, t_bat, fconst_bat, f_bat = split_list(
      orig_bat, [1, true_nconsts, true_nops, false_nconsts])

  # First pass: discover which outputs come out batched for each branch, then
  # take the union so both branches produce outputs with matching batchedness.
  _, true_out_bat = batching.batch_jaxpr(true_jaxpr, size, tconst_bat + t_bat, False)
  _, false_out_bat = batching.batch_jaxpr(false_jaxpr, size, fconst_bat + f_bat, False)
  out_bat = [a or b for a, b in zip(true_out_bat, false_out_bat)]

  true_jaxpr_batched, _ = batching.batch_jaxpr(true_jaxpr, size, tconst_bat + t_bat, out_bat)
  false_jaxpr_batched, _ = batching.batch_jaxpr(false_jaxpr, size, fconst_bat + f_bat, out_bat)

  if pred_bat:
    # Batched predicate: run both branches and select elementwise on pred.
    true_out = core.jaxpr_as_fun(true_jaxpr_batched)(*(true_consts + true_ops))
    false_out = core.jaxpr_as_fun(false_jaxpr_batched)(*(false_consts + false_ops))
    true_out = [batching.broadcast(x, size, 0) if not b else x
                for x, b in zip(true_out, out_bat)]
    false_out = [batching.broadcast(x, size, 0) if not b else x
                 for x, b in zip(false_out, out_bat)]
    return [_cond_pred_bcast_select(pred, t, f)
            for t, f in zip(true_out, false_out)], [0] * len(true_out)
  else:
    # Unbatched predicate: keep the cond, just swap in the batched branches.
    out_dims = [0 if b else batching.not_mapped for b in out_bat]
    return cond_p.bind(
        *itertools.chain([pred], true_consts, true_ops, false_consts, false_ops),
        true_jaxpr=true_jaxpr_batched, false_jaxpr=false_jaxpr_batched,
        true_nconsts=len(true_consts), false_nconsts=len(false_consts)), out_dims
def checkify_while_body_jaxpr(cond_jaxpr, body_jaxpr, error):
  """Wraps a while-loop body so the *next* cond application is also checked."""
  run_cond = core.jaxpr_as_fun(cond_jaxpr)
  run_body = core.jaxpr_as_fun(body_jaxpr)

  def checked_body(*vals):
    carry = run_body(*vals)
    # Evaluate the cond on the new carry purely for its error effects; the
    # cond's value itself is discarded here.
    run_cond(*carry)
    return carry

  return checkify_fun_to_jaxpr(lu.wrap_init(checked_body), error, body_jaxpr.in_avals)
def checkify_while_body_jaxpr(cond_jaxpr, body_jaxpr, error):
  """Wraps a while-loop body so the cond applied to its inputs is checked too."""
  run_cond = core.jaxpr_as_fun(cond_jaxpr)
  run_body = core.jaxpr_as_fun(body_jaxpr)

  def checked_body(*vals):
    # Evaluate the cond on the incoming carry for its error effects only.
    run_cond(*vals)
    return run_body(*vals)

  return checkify_fun_to_jaxpr(lu.wrap_init(checked_body), error, body_jaxpr.in_avals)
def _reap_while_rule(trace: HarvestTrace, *tracers, cond_jaxpr, body_jaxpr,
                     cond_nconsts, body_nconsts):
  """Reaps the body of a while loop to get the reaps of the final iteration."""
  cond_const_tracers, body_const_tracers, init_tracers = jax_util.split_list(
      tracers, [cond_nconsts, body_nconsts])
  _, init_avals = tree_util.tree_map(lambda x: x.aval,
                                     (body_const_tracers, init_tracers))
  cond_const_vals, body_const_vals, init_vals = tree_util.tree_map(
      lambda x: x.val, (cond_const_tracers, body_const_tracers, init_tracers))
  context = trace_util.get_dynamic_context(trace)
  settings = context.settings
  # Probe the body for sown values so we know what reaps to thread through the
  # loop carry.
  body_metadata = _get_harvest_metadata(body_jaxpr, settings,
                                        *(body_const_tracers + init_tracers))
  # Only 'clobber' mode is legal here: each iteration overwrites the previous
  # reap, so the loop output holds the final iteration's reaps.
  for k, meta in body_metadata.items():
    mode = meta['mode']
    if mode != 'clobber':
      raise ValueError(
          f'Must use clobber mode for \'{k}\' inside of a `while_loop`.')
  reap_avals = {k: v['aval'] for k, v in body_metadata.items()}
  cond_fun = jax_core.jaxpr_as_fun(cond_jaxpr)
  body_fun = jax_core.jaxpr_as_fun(body_jaxpr)
  reap_settings = dict(
      tag=settings.tag,
      allowlist=settings.allowlist,
      blocklist=settings.blocklist,
      exclusive=settings.exclusive)

  def new_cond(carry, _):
    # The reap slots in the carry don't affect the termination condition.
    return cond_fun(*(cond_const_vals + carry))

  def new_body(carry, _):
    # Run the body and collect its reaps; the reaps replace the previous
    # iteration's reap slots in the carry.
    carry, reaps = call_and_reap(
        body_fun, **reap_settings)(*(body_const_vals + carry))
    return (carry, reaps)

  new_in_avals, new_in_tree = tree_util.tree_flatten((init_avals, reap_avals))
  new_cond_jaxpr, cond_consts, _ = lcf._initial_style_jaxpr(  # pylint: disable=protected-access
      new_cond, new_in_tree, tuple(new_in_avals))
  new_body_jaxpr, body_consts, out_tree = lcf._initial_style_jaxpr(  # pylint: disable=protected-access
      new_body, new_in_tree, tuple(new_in_avals))
  # Seed the reap carry slots with zeros of the right shape/dtype.
  dummy_reap_vals = tree_util.tree_map(lambda x: jnp.zeros(x.shape, x.dtype),
                                       reap_avals)
  new_in_vals = tree_util.tree_leaves((init_vals, dummy_reap_vals))
  out = lax.while_p.bind(
      *(cond_consts + body_consts + new_in_vals),
      cond_nconsts=len(cond_consts),
      body_nconsts=len(body_consts),
      cond_jaxpr=new_cond_jaxpr,
      body_jaxpr=new_body_jaxpr)
  out = jax_util.safe_map(trace.pure, out)
  out, reaps = tree_util.tree_unflatten(out_tree, out)
  # Re-sow the final-iteration reaps into the enclosing trace.
  for k, v in reaps.items():
    sow(v, name=k, tag=settings.tag, mode=body_metadata[k]['mode'])
  return out
def _cond_impl(pred, *args, **kwargs):
  """Concrete evaluation rule for cond_p: run the selected branch.

  Args is the flat list [true_consts, true_ops, false_consts, false_ops];
  `pred` picks which branch's jaxpr is applied to its consts + operands.
  """
  true_jaxpr, false_jaxpr, true_nconsts, false_nconsts = split_dict(
      kwargs, ["true_jaxpr", "false_jaxpr", "true_nconsts", "false_nconsts"])
  # A branch jaxpr's inputs are its consts followed by its operands, so the
  # operand count is total inputs minus consts (matching the batching rule).
  # Using len(in_avals) directly here would misalign the false-branch split.
  true_nops = len(true_jaxpr.in_avals) - true_nconsts
  true_consts, true_ops, false_consts, false_ops = split_list(
      args, [true_nconsts, true_nops, false_nconsts])
  if pred:
    return core.jaxpr_as_fun(true_jaxpr)(*(true_consts + true_ops))
  else:
    return core.jaxpr_as_fun(false_jaxpr)(*(false_consts + false_ops))
def checkify_while_body_jaxpr(cond_jaxpr, body_jaxpr, error, enabled_errors, c_consts):
  """Wraps a while-loop body so the next cond application is error-checked."""
  run_cond = core.jaxpr_as_fun(cond_jaxpr)
  run_body = core.jaxpr_as_fun(body_jaxpr)

  def checked_body(*vals):
    carry = run_body(*vals)
    # Apply the cond (with its constants) to the freshly computed carry so
    # that an error in the upcoming cond evaluation is caught in this step.
    run_cond(*c_consts, *carry)
    return carry

  return checkify_fun_to_jaxpr(lu.wrap_init(checked_body), error, enabled_errors,
                               body_jaxpr.in_avals)
def while_loop_error_check(error, enabled_errors, *in_flat, cond_nconsts, cond_jaxpr, body_nconsts, body_jaxpr): c_consts, b_consts, carry = split_list(in_flat, [cond_nconsts, body_nconsts]) # Check if the first cond application will error. cond_jaxpr_, msgs_cond = checkify_jaxpr(cond_jaxpr, error, enabled_errors) cond_err, cond_code, cond_payload, _ = core.jaxpr_as_fun(cond_jaxpr_)( error.err, error.code, error.payload, *c_consts, *carry) del cond_jaxpr_ checked_body_jaxpr_, msgs_body = checkify_while_body_jaxpr( cond_jaxpr, body_jaxpr, error, enabled_errors, c_consts) to_move = [False] * 3 + [True] * body_nconsts + [False] * len(carry) checked_body_jaxpr = pe.move_binders_to_front(checked_body_jaxpr_, to_move) compat_cond_jaxpr_ = ignore_errors_jaxpr(cond_jaxpr, error) to_move = [False] * 3 + [True] * cond_nconsts + [False] * len(carry) compat_cond_jaxpr = pe.move_binders_to_front(compat_cond_jaxpr_, to_move) new_in_flat = [ *c_consts, *b_consts, cond_err, cond_code, cond_payload, *carry ] err, code, payload, *out = lax.while_p.bind(*new_in_flat, cond_nconsts=cond_nconsts, cond_jaxpr=compat_cond_jaxpr, body_nconsts=body_nconsts, body_jaxpr=checked_body_jaxpr) new_msgs = {**error.msgs, **msgs_body, **msgs_cond} return out, Error(err, code, new_msgs, payload)
def _custom_linear_solve_impl(*args, **kwargs):
  """Concrete evaluation for custom_linear_solve: apply the solve jaxpr to b."""
  const_lengths, jaxprs, tree = split_dict(
      kwargs, ['const_lengths', 'jaxprs', 'tree'])
  params, b = _split_linear_solve_args(args, const_lengths)
  solve = core.jaxpr_as_fun(jaxprs.solve)
  solution = solve(*params.solve, *b)
  # Sanity-check that the solution has the same structure/shapes as b.
  _check_shapes('solve', 'b', solution, b, tree)
  return solution
def _scan_callback_rule(trace, *tracers, reverse, length, num_consts, num_carry,
                        jaxpr, linear, unroll):
  """Callback rule for `scan_p`: transform the body, then rebind the scan.

  The original scan constants are closed over while retracing the transformed
  body (they surface as the new jaxpr's constants), so the rebound scan takes
  the new consts plus the original carry and xs values.
  """
  const_tracers, carry_tracers, xs_tracers = split_list(tracers, [num_consts, num_carry])
  carry_avals, xs_avals = tree_map(lambda x: x.aval, (carry_tracers, xs_tracers))
  const_vals, carry_vals, xs_vals = tree_map(
      lambda x: x.val, (const_tracers, carry_tracers, xs_tracers))
  # Per-iteration avals: index the leading (scan) axis off each xs value.
  x_tracers = [t[0] for t in xs_tracers]
  x_avals = [t.aval for t in x_tracers]
  body_fun = jaxpr_as_fun(jaxpr)

  def new_body(*vals):
    # The body jaxpr expects consts + carry + x; close over the consts so the
    # retraced body takes only carry + x. (Previously `const_vals` was unused
    # and `body_fun` was called with too few arguments.)
    out = body_fun(*it.chain(const_vals, vals))
    out_carry, y = split_list(out, [num_carry])
    return out_carry, y

  new_body = callback_transform(
      new_body, trace.callback, strip_calls=trace.strip_calls)  # type: ignore
  in_tree = tree_structure(carry_avals + xs_avals)
  new_jaxpr, new_consts, _ = lcf._initial_style_jaxpr(
      new_body, in_tree, tuple(carry_avals + x_avals))
  vals = tuple(it.chain(new_consts, carry_vals, xs_vals))
  out_vals = lax.scan_p.bind(*vals, reverse=reverse, length=length,
                             num_consts=len(new_consts), num_carry=num_carry,
                             jaxpr=new_jaxpr, linear=linear, unroll=unroll)
  return safe_map(trace.pure, out_vals)
def callback_jaxpr(closed_jaxpr, callback, strip_calls):
  """Applies a callback transform to a closed jaxpr and re-stages the result."""
  wrapped = lu.wrap_init(jaxpr_as_fun(closed_jaxpr))
  wrapped = callback_subtrace(wrapped)
  wrapped = _callback_fun(wrapped, callback, strip_calls)
  # Retrace on the original input avals to produce the transformed jaxpr.
  out_jaxpr, out_consts = cd._initial_style_jaxpr(wrapped, closed_jaxpr.in_avals)
  return core.ClosedJaxpr(out_jaxpr, out_consts)
def body_fun(i, vals):
  # Map the loop counter to the scanned index (reversed when not forward).
  carry, ys = vals
  idx = i if forward else length - i - 1
  x = _index_arrays(idx, x_aval, xs)
  carry_out, y = core.jaxpr_as_fun(jaxpr)(consts, carry, x)
  # Store this iteration's output slice back into the ys accumulator.
  return (carry_out, _update_arrays(idx, y_aval, ys, y))
def get_bind_params(self, params):
  """Recovers bind-time (callable, params) from trace-time params."""
  assert 'call_jaxpr' in params
  assert 'transpose_jaxpr_thunk' in params
  bind_params = dict(params)
  # Rebuild the transpose function from its staged-out thunk.
  thunk = bind_params.pop('transpose_jaxpr_thunk')
  bind_params['transpose'] = make_transpose_from_thunk(
      thunk, bind_params['lin_tree'])
  call_jaxpr = bind_params.pop('call_jaxpr')
  wrapped_call = lu.wrap_init(core.jaxpr_as_fun(call_jaxpr))
  return [wrapped_call], bind_params
def f_aug(*args):
  # Run the jaxpr, then scatter its residuals into a full-width residual list
  # (zeros everywhere else) so all callers agree on the residual layout.
  results = core.jaxpr_as_fun(jaxpr)(*args)
  primal_outs, res = split_list(results, [num_non_res_outputs])
  full_res = _map(ad_util.zeros_like_aval, all_res_avals)
  full_res = util.subvals(full_res, zip(res_indices, res))
  return primal_outs + list(full_res)
def body_fun(i, vals):
  carry, ys = vals
  # Walk xs front-to-back or back-to-front depending on the scan direction.
  idx = i if forward else length - i - 1
  x = _index_arrays(idx, x_aval, xs)
  cell = parametrized(jc.jaxpr_as_fun(jaxpr))
  new_carry, y = cell.apply(cell_params, consts, carry, x)
  # Write this step's output slice into the ys accumulator.
  return new_carry, _update_arrays(idx, y_aval, ys, y)
def body_fun(i, vals):
  # Map the loop counter to the scanned index (reversed when not forward).
  idx = i if forward else length - i - 1
  carry, ys = split_list(vals, [num_carry])
  x = _map(partial(_index_array, idx), x_avals, xs)
  outs = core.jaxpr_as_fun(jaxpr)(*consts, *carry, *x)
  new_carry, y_updates = split_list(outs, [num_carry])
  # Write each per-step output into its ys accumulator at position idx.
  new_ys = _map(partial(_update_array, idx), y_avals, ys, y_updates)
  return new_carry + new_ys
def eval_sparse(
    jaxpr: core.Jaxpr,
    consts: Sequence[Array],  # all consts are dense
    argspecs: Sequence[ArgSpec],  # mix of sparse and dense pointers into spenv
    spenv: SparseEnv,
  ) -> Sequence[ArgSpec]:
  """Interprets a jaxpr, dispatching to sparse rules when inputs are sparse.

  Equations whose inputs are all dense are evaluated with the ordinary
  primitive bind; equations with any sparse input use `sparse_rules`.
  """
  # Environment mapping jaxpr variables to ArgSpec handles into spenv.
  env: Dict[core.Var, ArgSpec] = {}

  def read(var: core.Var) -> Union[Array, ArgSpec]:
    # all literals are dense
    if isinstance(var, core.Literal):
      return ArgSpec(np.shape(var.val), spenv.push(var.val), None)
    else:
      return env[var]

  def write_buffer(var: core.Var, a: Array) -> None:
    # Store a dense array result under `var` (ignored for dropped outputs).
    if var is core.dropvar:
      return
    env[var] = ArgSpec(a.shape, spenv.push(a), None)

  def write(var: core.Var, a: ArgSpec) -> None:
    if var is core.dropvar:
      return
    env[var] = a

  # TODO: handle unitvar at all?
  #write_buffer(core.unitvar, core.unit)
  safe_map(write_buffer, jaxpr.constvars, consts)
  safe_map(write, jaxpr.invars, argspecs)
  for eqn in jaxpr.eqns:
    prim = eqn.primitive
    invals = safe_map(read, eqn.invars)

    if any(val.is_sparse() for val in invals):
      # At least one sparse input: a dedicated sparse rule is required.
      if prim not in sparse_rules:
        raise NotImplementedError(f"sparse rule for {prim}")
      out = sparse_rules[prim](spenv, *invals, **eqn.params)
    else:
      # All inputs dense: evaluate with the ordinary primitive bind.
      if prim is xla.xla_call_p:
        # TODO(vanderplas,frostig): workaround for binding call primitives
        # within a jaxpr interpreter
        params = eqn.params.copy()
        fun = lu.wrap_init(
            core.jaxpr_as_fun(pe.ClosedJaxpr(params.pop('call_jaxpr'), ())))
        out_bufs = prim.bind(fun, *(val.data(spenv) for val in invals), **params)
      else:
        out_bufs = prim.bind(*(val.data(spenv) for val in invals), **eqn.params)
      out_bufs = out_bufs if prim.multiple_results else [out_bufs]
      out = []
      for buf in out_bufs:
        out.append(ArgSpec(buf.shape, spenv.push(buf), None))
    safe_map(write, eqn.outvars, out)
  return safe_map(read, jaxpr.outvars)
def _custom_ivjp_jvp(primals, tangents, *, fun_jaxpr, ivjp_jaxpr):
  """JVP rule: primals via custom_ivjp_p, tangents via an ordinary jvp of fun."""
  outs = custom_ivjp_p.bind(*primals, fun_jaxpr=fun_jaxpr,
                            ivjp_jaxpr=ivjp_jaxpr)
  # FIXME: This might compute the primals multiple times, but we only need to
  # do this trick while linearizing. It should be possible to do it through
  # a custom partial eval rule.
  wrapped = lu.wrap_init(core.jaxpr_as_fun(fun_jaxpr))
  _, tangent_outs = ad.jvp(wrapped).call_wrapped(primals, tangents)
  return outs, tangent_outs
def _batch_jaxpr_axes(closed_jaxpr, axis_size, in_axes, out_axes_dest, axis_name, main_type):
  """Batches a closed jaxpr along the given input axes by retracing it."""
  fun = lu.wrap_init(core.jaxpr_as_fun(closed_jaxpr))
  fun, out_batched = _batch_jaxpr_inner(fun, axis_size, out_axes_dest)
  fun = _batch_jaxpr_outer(fun, axis_name, axis_size, in_axes, main_type)
  avals_in = []
  for aval, axis in zip(closed_jaxpr.in_avals, in_axes):
    if axis is not_mapped:
      avals_in.append(aval)
    else:
      # Grow the aval to include the mapped axis before retracing.
      avals_in.append(core.unmapped_aval(axis_size, axis_name, axis, aval))
  jaxpr_out, _, consts = pe.trace_to_jaxpr_dynamic(fun, avals_in)
  return core.ClosedJaxpr(jaxpr_out, consts), out_batched()
def _jvp_jaxpr(jaxpr, nonzeros, instantiate):
  """Stages out the jvp of a closed jaxpr, with tangent inputs only where nonzero."""
  assert len(jaxpr.in_avals) == len(nonzeros)
  fun = lu.wrap_init(core.jaxpr_as_fun(jaxpr))
  jvp_fun = jvp(fun, instantiate=instantiate, transform_stack=False)
  f_jvp, out_nonzeros = f_jvp_traceable(jvp_fun, nonzeros)
  # Tangent inputs exist only for slots whose nonzero flag is set.
  tangent_avals = [a for a, nz in zip(jaxpr.in_avals, nonzeros) if nz]
  avals_in = [*jaxpr.in_avals, *tangent_avals]
  jaxpr_out, _, literals_out = pe.trace_to_jaxpr_dynamic(f_jvp, avals_in)
  return core.ClosedJaxpr(jaxpr_out, literals_out), out_nonzeros()
def _xla_call_sparse(spenv, *spvalues, call_jaxpr, donated_invars, **params):
  """Sparse rule for xla_call_p: sparsify the called jaxpr and re-bind it."""
  if any(donated_invars):
    raise NotImplementedError("sparse xla_call with donated_invars")
  sparse_jaxpr, out_tree = _sparsify_jaxpr(
      spenv, pe.ClosedJaxpr(call_jaxpr, ()), *spvalues)
  wrapped = lu.wrap_init(core.jaxpr_as_fun(sparse_jaxpr))
  flat_args, _ = tree_flatten(spvalues_to_arrays(spenv, spvalues))
  # The flattened argument list has a different arity, so donation is reset.
  donated_invars = (False,) * len(flat_args)
  flat_out = xla.xla_call_p.bind(wrapped, *flat_args,
                                 donated_invars=donated_invars, **params)
  return arrays_to_spvalues(spenv, tree_unflatten(out_tree, flat_out))
def checkify_jaxpr(jaxpr, error):
  """Stages a checked version of `jaxpr` that threads error state up front."""
  fun = lu.wrap_init(core.jaxpr_as_fun(jaxpr))
  fun, msgs = check_errors_subtrace(fun)
  fun = check_errors_traceable(fun, tuple(error.msgs.items()))
  # The checked jaxpr takes the error value and code before the original inputs.
  avals_in = [
      core.raise_to_shaped(core.get_aval(error.err)),
      core.raise_to_shaped(core.get_aval(error.code)),
      *jaxpr.in_avals,
  ]
  jaxpr_out, _, literals_out = pe.trace_to_jaxpr_dynamic(fun, avals_in)
  return core.ClosedJaxpr(jaxpr_out, literals_out), msgs()
def eval_sparse(
    jaxpr: core.Jaxpr,
    consts: Sequence[Array],  # all consts are dense
    spvalues: Sequence[SparsifyValue],  # mix of sparse and dense pointers into spenv
    spenv: SparsifyEnv,
  ) -> Sequence[SparsifyValue]:
  """Interprets a jaxpr, dispatching to sparse rules when inputs are sparse.

  Equations whose inputs are all dense are evaluated with the ordinary
  primitive bind; equations with any sparse input use `sparse_rules`.
  """
  # Environment mapping jaxpr variables to SparsifyValue handles into spenv.
  env : Dict[core.Var, SparsifyValue] = {}

  def read(var: core.Var) -> Union[Array, SparsifyValue]:
    # all literals are dense
    if isinstance(var, core.Literal):
      return spenv.dense(var.val)
    else:
      return env[var]

  def write_buffer(var: core.Var, a: Array) -> None:
    # Store a dense array result under `var` (ignored for dropped outputs).
    if isinstance(var, core.DropVar):
      return
    env[var] = spenv.dense(a)

  def write(var: core.Var, a: SparsifyValue) -> None:
    if isinstance(var, core.DropVar):
      return
    assert a is not None
    env[var] = a

  safe_map(write_buffer, jaxpr.constvars, consts)
  safe_map(write, jaxpr.invars, spvalues)

  for eqn in jaxpr.eqns:
    prim = eqn.primitive
    invals = safe_map(read, eqn.invars)

    if any(val.is_sparse() for val in invals):
      # At least one sparse input: a dedicated sparse rule is required.
      if prim not in sparse_rules:
        _raise_unimplemented_primitive(prim)
      out = sparse_rules[prim](spenv, *invals, **eqn.params)
    else:
      # All inputs dense: evaluate with the ordinary primitive bind.
      if prim is xla.xla_call_p:
        # TODO(vanderplas,frostig): workaround for binding call primitives
        # within a jaxpr interpreter
        params = eqn.params.copy()
        fun = lu.wrap_init(
            core.jaxpr_as_fun(pe.ClosedJaxpr(params.pop('call_jaxpr'), ())))
        out_bufs = prim.bind(fun, *(spenv.data(val) for val in invals), **params)
      else:
        out_bufs = prim.bind(*(spenv.data(val) for val in invals), **eqn.params)
      out_bufs = out_bufs if prim.multiple_results else [out_bufs]
      out = []
      for buf, outvar in safe_zip(out_bufs, eqn.outvars):
        if isinstance(outvar, core.DropVar):
          out.append(None)
        else:
          out.append(spenv.dense(buf))
    safe_map(write, eqn.outvars, out)
  return safe_map(read, jaxpr.outvars)
def _cond_batching_rule(axis_size, axis_name, main_type, args, dims, branches, linear):
  """Batching (vmap) rule for the indexed-branch cond primitive.

  If the branch index is batched, the cond is promoted to an elementwise
  select over all branch outputs; otherwise cond is rebound with batched
  branch jaxprs.
  """
  index, *ops = args
  index_dim, *op_dims = dims

  if index_dim is not batching.not_mapped:
    # Convert to a lax.select. While we could get away with not broadcasting
    # some operands yet, because all outputs must be broadcast together anyway
    # for the select we broadcast the input operands for simplicity and leave
    # optimizations to XLA.
    # TODO(mattjj,frostig): assumes branches are side-effect-free, revise!
    index, *ops = (batching.bdim_at_front(x, d, axis_size)
                   for x, d in zip(args, dims))

    in_batched = [True] * len(branches[0].in_avals)
    out_batched = [True] * len(branches[0].out_avals)

    branches_batched = [
        batching.batch_jaxpr(jaxpr, axis_size, in_batched, out_batched,
                             axis_name, main_type)[0]
        for jaxpr in branches]

    branch_outs = []
    for i, jaxpr in enumerate(branches_batched):
      # Perform a select on the inputs for safety of reverse-mode autodiff; see
      # https://github.com/google/jax/issues/1052
      predicate = lax.eq(index, lax._const(index, i))
      ops_ = [_bcast_select(predicate, x, lax.stop_gradient(x)) for x in ops]
      branch_outs.append(core.jaxpr_as_fun(jaxpr)(*ops_))
    out = [_bcast_select_n(index, *outs) for outs in zip(*branch_outs)]
    return out, [0 if b else None for b in out_batched]
  else:
    # Unbatched index: keep the cond, just batch each branch jaxpr.
    ops_bat = [d is not batching.not_mapped for d in op_dims]
    ops = [batching.moveaxis(x, d, 0) if b else x
           for b, x, d in zip(ops_bat, ops, op_dims)]

    # First pass: discover each branch's output batchedness, then take the
    # union so all branches produce outputs with matching batchedness.
    branches_out_bat = [
        batching.batch_jaxpr(jaxpr, axis_size, ops_bat, False, axis_name,
                             main_type)[1]
        for jaxpr in branches]
    out_bat = [any(bat) for bat in zip(*branches_out_bat)]

    branches_batched = tuple(
        batching.batch_jaxpr(jaxpr, axis_size, ops_bat, out_bat, axis_name,
                             main_type)[0]
        for jaxpr in branches)

    out_dims = [0 if b else batching.not_mapped for b in out_bat]
    out = cond_p.bind(index, *ops, branches=branches_batched, linear=linear)
    return out, out_dims
def do_transpose(primals_in, cotangents_in):
  # NOTE: Undefined primals stand in for the tangent arguments here; that is
  # fine because only the primal part of the computation is consumed.
  primal_outs = core.jaxpr_as_fun(primal_jaxpr)(*primals_in)
  residuals = primal_outs[len(cotangents_in):]
  # The tangent jaxpr is purely linear, so it can be transposed directly.
  all_cotangents = backward_pass(tangent_jaxpr.jaxpr, reduce_axes, (),
                                 primals_in + residuals, cotangents_in)
  # backward_pass yields cotangents for every invar, including the residuals
  # appended by partial eval; drop those trailing entries before returning.
  return all_cotangents[:len(primals_in)]
def _remat_transpose(primal_jaxpr, tangent_jaxpr, reduce_axes,
                     primals_tangents_in, cotangents_in):
  """Transpose rule for remat: rerun the primal part, transpose the linear part."""
  known, unknown = [], []
  for x in primals_tangents_in:
    (unknown if is_undefined_primal(x) else known).append(x)
  # Recompute the residuals from the known primal inputs.
  res = core.jaxpr_as_fun(primal_jaxpr)(*known)
  ct_all = backward_pass(tangent_jaxpr.jaxpr, reduce_axes, False, (),
                         (*res, *unknown), cotangents_in)
  # Skip cotangents for the residual slots, then match the remaining ones back
  # to the undefined-primal positions in the original argument order.
  ct_iter = iter(ct_all[len(res):])
  outs = [next(ct_iter) if is_undefined_primal(x) else Zero.from_value(x)
          for x in primals_tangents_in]
  assert next(ct_iter, None) is None
  return outs
def _interpret_jaxpr(jaxpr: core.TypedJaxpr, *args: TfValOrUnit) -> Sequence[TfVal]:
  """Evaluates a Jaxpr with tf.Tensor arguments.

  Arguments may be TfVal or TfValOrUnit; they are replaced with `core.unit`
  wherever the `jaxpr` expects a unit. The result is a sequence of TfVal with
  no `core.unit` entries, suitable for use with TF.
  """
  wrapped: lu.WrappedFun = lu.wrap_init(core.jaxpr_as_fun(jaxpr))
  unit_args: Sequence[TfValOrUnit] = _tfval_add_unit(args, jaxpr.in_avals)
  outs: Sequence[TfValOrUnit] = _interpret_fun(wrapped, unit_args)
  return _tfval_remove_unit(outs)
def _masked_scan_jaxpr(jaxpr, num_consts, num_carry):
  """Wraps a scan body so steps past `dynamic_length` leave the carry unchanged."""
  body = core.jaxpr_as_fun(jaxpr)

  @lu.wrap_init
  def masked(*args):
    [dynamic_length], consts, [i], carry, xs = split_list(
        args, [1, num_consts, 1, num_carry])
    outs = body(*consts, *carry, *xs)
    carry_out, ys = split_list(outs, [num_carry])
    # Past the dynamic length, keep the previous carry instead of the update.
    carry_out = [lax.select(i < dynamic_length, new_c, old_c)
                 for new_c, old_c in zip(carry_out, carry)]
    return [i + 1] + carry_out + ys

  index_aval = ShapedArray((), onp.int64)
  const_avals, carry_avals, x_avals = split_list(
      jaxpr.in_avals, [num_consts, num_carry])
  # New inputs: dynamic_length, consts, step counter, carry, xs.
  return _make_typed_jaxpr(
      masked, [index_aval] + const_avals + [index_aval] + carry_avals + x_avals)
def _get_harvest_metadata(closed_jaxpr, settings, *args):
  """Probes a jaxpr for metadata like its sown values."""
  wrapped = lu.wrap_init(jax_core.jaxpr_as_fun(closed_jaxpr))
  with jax_core.new_main(HarvestTrace) as main:
    probe_settings = HarvestSettings(settings.tag, settings.blocklist,
                                     settings.allowlist, True)
    wrapped = reap_function(wrapped, main, probe_settings, True)
    wrapped, aux = _reap_metadata_wrapper(wrapped)
    flat_args, in_tree = tree_util.tree_flatten(args)
    flat_fun, out_tree = api_util.flatten_fun_nokwargs(wrapped, in_tree)
    # Abstract the arguments so tracing performs no real computation.
    in_avals = jax_util.safe_map(
        lambda a: abstract_arrays.raise_to_shaped(jax_core.get_aval(a)),
        flat_args)
    # Tracing is run only for its side effect of populating `aux`.
    pe.trace_to_jaxpr_final(flat_fun, in_avals)
    metadata = aux()
    out_tree()
  return metadata