def _prune_unused_inputs( jaxpr: core.Jaxpr) -> Tuple[core.Jaxpr, Set[int], Set[int]]: used = {v for v in jaxpr.outvars if isinstance(v, core.Var)} # TODO(zhangqiaorjc): Improve the DCE algorithm by also pruning primitive # applications that do not produce used outputs. Must handle side-effecting # primitives and nested jaxpr. used.update( v for eqn in jaxpr.eqns for v in eqn.invars if isinstance(v, core.Var)) kept_const_idx, new_constvars = util.unzip2( (i, v) for i, v in enumerate(jaxpr.constvars) if v in used) kept_var_idx, new_invars = util.unzip2( (i, v) for i, v in enumerate(jaxpr.invars) if v in used) new_jaxpr = core.Jaxpr(new_constvars, new_invars, jaxpr.outvars, jaxpr.eqns) return new_jaxpr, set(kept_const_idx), set(kept_var_idx)
def post_process_call(self, call_primitive, out_tracers, params):
  """Unwrap mask tracers leaving a call; return values plus a re-wrap thunk.

  The returned `todo` re-lifts the raw values into a fresh MaskTrace at the
  then-current sublevel, reattaching the saved polymorphic shapes.
  """
  vals, shapes = unzip2((t.val, t.polymorphic_shape) for t in out_tracers)
  main = self.main

  def todo(xs):
    # Rebuild tracers at whatever sublevel is current when the thunk runs.
    trace = MaskTrace(main, core.cur_sublevel())
    return [MaskTracer(trace, x, s) for x, s in zip(xs, shapes)]

  return vals, todo
def process_call(self, call_primitive, f: lu.WrappedFun, tracers, params):
  """JVP rule for call (and map) primitives: trace `f`'s JVP under the call.

  Splits tracers into primals/tangents, wraps `f` so it computes both, and
  rebinds the call primitive on the combined arguments. Symbolic-zero
  tangents are filtered out before binding.
  """
  assert call_primitive.multiple_results
  primals, tangents = unzip2((t.primal, t.tangent) for t in tracers)
  nonzero_tangents, tangent_tree_def = tree_flatten(tangents)
  # Per-input mask of which tangents are not symbolic zeros.
  nz_tangents = [type(t) is not Zero for t in tangents]
  if 'name' in params and not config.jax_experimental_name_stack:
    params = dict(params, name=wrap_name(params['name'], 'jvp'))
  f_jvp = jvp_subtrace(f, self.main)
  f_jvp, nz_tangents_out = nonzero_tangent_outputs(f_jvp)
  if isinstance(call_primitive, core.MapPrimitive):
    # Map primitives carry axis metadata that must be extended to cover the
    # appended (nonzero) tangent inputs and outputs.
    in_axes = params['in_axes']
    tangent_in_axes = [ax for ax, nz in zip(in_axes, nz_tangents) if nz]
    out_axes_thunk = params['out_axes_thunk']
    # The new thunk depends deterministically on the old thunk and the wrapped function.
    # Any caching already has to include the wrapped function as part of the key, so we
    # only use the previous thunk for equality checks.
    # NOTE: This assumes that the output tangents being zero is a deterministic
    # function of which input tangents were zero.
    @as_hashable_function(closure=(tuple(nz_tangents), out_axes_thunk))
    def new_out_axes_thunk():
      out_axes = out_axes_thunk()
      return (*out_axes, *(ax for ax, nz in zip(out_axes, nz_tangents_out()) if nz))
    params = dict(params, in_axes=(*in_axes, *tangent_in_axes),
                  out_axes_thunk=new_out_axes_thunk)
  f_jvp, out_tree_def = traceable(f_jvp, len(primals), tangent_tree_def)
  update_params = call_param_updaters.get(call_primitive)
  new_params = update_params(params, nz_tangents) if update_params else params
  result = call_primitive.bind(f_jvp, *primals, *nonzero_tangents, **new_params)
  primal_out, tangent_out = tree_unflatten(out_tree_def(), result)
  return [JVPTracer(self, p, t) for p, t in zip(primal_out, tangent_out)]
def process_custom_transpose(self, prim, call, tracers, **params):
  """JVP rule for custom-transpose calls.

  Relies on `call` being linear in its non-residual operands: the tangent
  output is obtained by re-binding `call` on the (fixed) residual primals
  and the linear-operand tangents.
  """
  ps_in, ts_in = unzip2((t.primal, t.tangent) for t in tracers)
  # The leading `res_tree.num_leaves` operands are residuals; the rest are
  # the linear operands.
  res_ps_in, lin_ps_in = split_list(ps_in, [params['res_tree'].num_leaves])
  res_ts_in, lin_ts_in = split_list(ts_in, [params['res_tree'].num_leaves])

  # TODO(frostig): Handle differentiation with respect to residual
  # operands. Calling `call` twice on all operands is invalid, since it
  # isn't linear in the residuals. However, we know that if we
  # write:
  #
  #   jvp_call_res = lambda x: partial(jvp, lambda r: call(r, x))
  #
  # then:
  #
  #   jvp(call, (r, x), (dr, dx)) == jvp_call_res(x)(r, dr) + call(r, dx)
  #
  # In words: a possible strategy is to take the jvp of `call` with
  # respect to residuals, and with linear arguments fixed, then add
  # that to a custom-transpose call to `call` (i.e. what we already
  # do below in the all-linear argument case).
  if any(type(t) is not Zero for t in res_ts_in):
    raise NotImplementedError(
        'JVP of custom transpose with respect to non-symbolic-zero residuals')

  ps_out = prim.bind(call, *ps_in, **params)
  lin_ts_in = map(instantiate_zeros, lin_ts_in)
  ts_out = prim.bind(call, *res_ps_in, *lin_ts_in, **params)
  return map(partial(JVPTracer, self), ps_out, ts_out)
def _ravel_list(lst):
  """Ravel a list of arrays into one 1-D array plus an `unravel` inverse.

  Returns (raveled, unravel). When all inputs share a dtype, `unravel` is
  dtype-polymorphic (no casts); otherwise inputs are promoted to their
  common result type and `unravel` casts each chunk back.
  """
  if not lst:
    return jnp.array([], jnp.float32), lambda _: []
  from_dtypes = [dtypes.dtype(l) for l in lst]
  to_dtype = dtypes.result_type(*from_dtypes)
  sizes, shapes = unzip2((jnp.size(x), jnp.shape(x)) for x in lst)
  # Split points for recovering the original chunks from the flat array.
  indices = np.cumsum(sizes)

  if all(dt == to_dtype for dt in from_dtypes):
    # Skip any dtype conversion, resulting in a dtype-polymorphic `unravel`.
    # See https://github.com/google/jax/issues/7809.
    del from_dtypes, to_dtype
    def unravel(arr):
      chunks = jnp.split(arr, indices[:-1])
      return [chunk.reshape(shape) for chunk, shape in zip(chunks, shapes)]
    raveled = jnp.concatenate([jnp.ravel(e) for e in lst])
    return raveled, unravel

  # When there is more than one distinct input dtype, we perform type
  # conversions and produce a dtype-specific unravel function.
  def unravel(arr):
    arr_dtype = dtypes.dtype(arr)
    if arr_dtype != to_dtype:
      raise TypeError(f"unravel function given array of dtype {arr_dtype}, "
                      f"but expected dtype {to_dtype}")
    chunks = jnp.split(arr, indices[:-1])
    with warnings.catch_warnings():
      warnings.simplefilter("ignore")  # ignore complex-to-real cast warning
      return [lax.convert_element_type(chunk.reshape(shape), dtype)
              for chunk, shape, dtype in zip(chunks, shapes, from_dtypes)]

  ravel = lambda e: jnp.ravel(lax.convert_element_type(e, to_dtype))
  raveled = jnp.concatenate([ravel(e) for e in lst])
  return raveled, unravel
def remat_partial_eval(trace, *tracers, jaxpr, **params):
  """Partial-evaluation rule for the remat primitive.

  Splits `jaxpr` into a known part (evaluated eagerly here, hoisted out of
  the remat) and an unknown part (staged as a new remat equation, marked
  `differentiated=True`), then zips known/unknown outputs back together.
  """
  assert not jaxpr.constvars
  policy = params['policy'] or (lambda *_, **__: False)
  # unzip into jaxpr_known and jaxpr_unknown
  in_unknowns = [not t.is_known() for t in tracers]
  jaxpr_known, jaxpr_unknown, out_unknowns, out_inst, _ = \
      pe._partial_eval_jaxpr_custom(jaxpr, in_unknowns, policy)
  # DCE both halves; track which inputs each half still consumes.
  jaxpr_known, in_used_known = pe.dce_jaxpr(jaxpr_known, [True] * len(jaxpr_known.outvars))
  _, used_outs_unknown = partition_list(out_inst, out_unknowns)
  jaxpr_unknown, in_used_unknown = pe.dce_jaxpr(jaxpr_unknown, used_outs_unknown)

  # compute known outputs and residuals (hoisted out of remat primitive)
  _, in_consts_ = unzip2(t.pval for t in tracers if t.pval.is_known())
  _, in_consts = partition_list(in_used_known, in_consts_)
  out_consts = core.eval_jaxpr(jaxpr_known, (), *in_consts)
  out_consts_ = iter(out_consts)
  # form known outputs and collect residual tracers
  out_known_tracers = [
      pe.JaxprTracer(trace, pe.PartialVal.known(next(out_consts_)), None)
      for uk in out_unknowns if not uk]
  residuals = list(out_consts_)

  # set up unknown outputs with a recipe to call remat
  res_tracers = map(trace.new_instantiated_const, residuals)
  in_jaxpr_tracers = [*res_tracers, *map(trace.instantiate_const, tracers)]
  _, in_jaxpr_tracers = partition_list(in_used_unknown, in_jaxpr_tracers)
  out_jaxpr_tracers = [pe.JaxprTracer(trace, pe.PartialVal.unknown(x.aval), None)
                       for x in jaxpr_unknown.outvars]
  new_params = dict(params, jaxpr=jaxpr_unknown, differentiated=True)
  recipe = pe.new_eqn_recipe(in_jaxpr_tracers, out_jaxpr_tracers, remat_p,
                             new_params, source_info_util.current())
  for t in out_jaxpr_tracers:
    t.recipe = recipe

  # zip together known and unknown outputs
  return pe._zip_knowns(out_known_tracers, out_jaxpr_tracers, out_unknowns)
def process_call(self, call_primitive, f, tracers, params):
  """Batching rule for call primitives.

  If no input is batched, rebinds the call unchanged; otherwise wraps `f`
  in a batching subtrace and rebuilds BatchTracers for its outputs.
  """
  assert call_primitive.multiple_results
  if config.jax_experimental_name_stack:
    params = dict(params, name=params.get('name', f.__name__))
  else:
    params = dict(params, name=wrap_name(params.get('name', f.__name__), 'vmap'))
  vals, dims = unzip2((t.val, t.batch_dim) for t in tracers)
  if all(bdim is not_mapped for bdim in dims):
    # Nothing is batched: no need to trace through `f` specially.
    return call_primitive.bind(f, *vals, **params)
  else:
    f_, dims_out = batch_subtrace(f, self.main, dims)
    # All batched inputs must agree on the batch-axis size.
    ax_size, = {x.shape[d] for x, d in zip(vals, dims) if d is not not_mapped}
    f_ = _update_annotation(f_, f.in_type, ax_size, self.axis_name, dims)
    vals_out = call_primitive.bind(f_, *vals, **params)
    src = source_info_util.current()
    return [BatchTracer(self, v, d, src) for v, d in zip(vals_out, dims_out())]
def to_jaxpr(self, in_dim_tracers, in_tracers, out_dim_tracers, out_tracers):
  """Assemble the traced equations into a DJaxpr.

  Returns (jaxpr, constvals, lambda_binders), where `lambda_binders` is a
  mask over the (consts + inputs) marking which binders stayed ordinary
  lambda binders vs. were promoted to dimension (pi) binders.
  """
  t2v = lambda t: self.tracer_to_var[id(t)]
  in_dim_binders, in_binders = map(t2v, in_dim_tracers), map(t2v, in_tracers)
  out_dims, outs = map(t2v, out_dim_tracers), map(t2v, out_tracers)

  # only include constants that are used
  used_vars = ({a for eqn in self.eqns for a in eqn.invars if isinstance(a, Var)} |
               {a for grp in [out_dims, outs] for a in grp if isinstance(a, Var)})
  constvars, constvals = unzip2(
      (v, c) for v, c in self.constvar_to_val.items() if v in used_vars)
  in_binders = [*constvars, *in_binders]

  # promote some lambda binders to pi binders
  used_shape_vars = ({d for eqn in self.eqns for v in eqn.outvars
                      if isinstance(v.aval, AbsArray)
                      for d in v.aval.shape if isinstance(d, Var)} |
                     {d.name for eqn in self.eqns for v in eqn.outvars
                      if isinstance(v.aval, AbsArray)
                      for d in v.aval.shape if isinstance(d, DimIndexingExpr)})
  lambda_binders = [v not in used_shape_vars for v in in_binders]
  converted_binders, in_binders = partition_list(lambda_binders, in_binders)
  in_dim_binders = in_dim_binders + converted_binders
  out_dims = [v for v in out_dims if v not in in_dim_binders]  # TODO

  jaxpr = DJaxpr(in_dim_binders, in_binders, out_dims, outs, self.eqns)
  typecheck_jaxpr(jaxpr)
  return jaxpr, constvals, lambda_binders
def testOpShardingRoundTrip(self):
  """Check that partition specs survive a round trip through OpSharding."""
  FakeDevice = namedtuple('FakeDevice', ['id'])
  mesh_named_shape = OrderedDict([('a', 2), ('b', 3), ('c', 4), ('d', 7), ('e', 4)])
  mesh_axes, mesh_shape = unzip2(mesh_named_shape.items())
  devices = [FakeDevice(i) for i in range(np.prod(list(mesh_shape)))]
  mesh = pxla.Mesh(np.array(devices).reshape(*mesh_shape), tuple(mesh_axes))
  dims = 5
  aval = jax.core.ShapedArray((len(devices),) * dims, jnp.float32)

  def roundtrip(spec):
    # Encode to an OpSharding proto, parse it back, and compare: the parsed
    # spec must match, with any extra trailing dims unpartitioned.
    op_sharding = pjit_lib.get_aval_sharding_proto(aval, spec, mesh)
    parsed_spec = pjit_lib.parse_op_sharding(op_sharding, mesh).partitions
    self.assertEqual(parsed_spec[:len(spec)], spec)
    self.assertEqual(parsed_spec[len(spec):], ((),) * (len(parsed_spec) - len(spec)))

  special_specs = [P()]
  for spec in special_specs:
    roundtrip(spec)

  # Fuzz with randomly-assigned mesh axes (fixed seed for determinism).
  rng = np.random.default_rng(1)
  for i in range(100):
    spec = [()] * dims
    for axis in rng.permutation(mesh_axes)[:rng.integers(low=1, high=len(mesh_axes) + 1)]:
      spec[rng.choice(dims)] += (axis,)
    roundtrip(P(*spec))
def post_process_map(self, call_primitive, out_tracers, params):
  """Unwrap batch tracers leaving a map; return values plus a re-wrap thunk.

  The `todo` thunk re-lifts raw values into BatchTracers, shifting each
  batch dim past the map output axis when both are present. For map
  primitives, also returns an `out_axes_transform` to shift the map's own
  output axes past the batch dims.
  """
  vals, dims = unzip2((t.val, t.batch_dim) for t in out_tracers)
  main = self.main
  def both_mapped(in_out_axis, d):
    return in_out_axis is not None and d is not not_mapped
  def todo(vals):
    trace = main.with_cur_sublevel()
    return [BatchTracer(trace, v,
                        d + 1 if both_mapped(out_axis, d) and out_axis <= d else d)
            for v, d, out_axis in zip(vals, dims, params['out_axes_thunk']())]
  if call_primitive.map_primitive:
    def out_axes_transform(out_axes):
      return tuple(out_axis + 1 if both_mapped(out_axis, d) and d < out_axis else out_axis
                   for out_axis, d in zip(out_axes, dims))
    todo = (todo, out_axes_transform)
  return vals, todo
def jet_subtrace(main, primals, series):
  """Run the wrapped computation under a JetTrace at the current sublevel.

  Generator transformation: wraps inputs as JetTracers, yields them to the
  wrapped function, then unpacks the results into (primals, term-series).
  """
  trace = JetTrace(main, core.cur_sublevel())
  tracers_in = [JetTracer(trace, p, s) for p, s in zip(primals, series)]
  outs = yield tracers_in, {}
  tracers_out = [trace.full_raise(o) for o in outs]
  yield unzip2((t.primal, t.terms) for t in tracers_out)
def cond_error_check(error, index, *ops, branches, linear):
  """Checkify rule for `cond`: thread the error value through every branch.

  Each branch jaxpr is rewritten to accept/return the error state; the
  per-branch message tables are merged into one.
  """
  checked = [checkify_jaxpr(jxpr, error) for jxpr in branches]
  new_branches, msgs_ = unzip2(checked)
  # The two prepended operands (err, code) are not linear.
  new_linear = (False, False, *linear)
  err, code, *outs = control_flow.cond_p.bind(
      index, error.err, error.code, *ops,
      branches=tuple(new_branches), linear=new_linear)
  # Merge message tables; later branches override on key collisions, same
  # as the original chained-dict comprehension.
  new_msgs = dict(error.msgs)
  for branch_msgs in msgs_:
    new_msgs.update(branch_msgs)
  return outs, Error(err, code, new_msgs)
def doubling_subtrace(main, heads, tails):
  """Run the wrapped computation under a DoublingTrace sublevel.

  Inputs with a `None` tail pass through unwrapped; others become
  DoublingTracers. Yields the output (heads, tails) pair.
  """
  trace = DoublingTrace(main, core.cur_sublevel())
  tracers_in = []
  for head, tail in zip(heads, tails):
    tracers_in.append(head if tail is None else DoublingTracer(trace, head, tail))
  outs = yield tracers_in, {}
  tracers_out = [trace.full_raise(o) for o in outs]
  yield unzip2((t.head, t.tail) for t in tracers_out)
def new_f(*args, **kwargs):
  """Invoke `f` under a device mesh described by `named_shape`.

  Skips the test when fewer local devices are available than the mesh
  requires.
  """
  axis_names, shape = unzip2(named_shape)
  size = np.prod(shape)
  available = list(jax.local_devices())
  if len(available) < size:
    raise SkipTest(f"Test requires {size} local devices")
  device_grid = np.array(available[:size]).reshape(shape)
  with mesh(device_grid, axis_names):
    return f(*args, **kwargs)
def process_call(self, call_primitive, f, tracers, params):
  """Masking rule for call primitives.

  If no input shape is polymorphic, rebinds the call unchanged; otherwise
  wraps `f` in a mask subtrace, prepending the (hashable) shape environment
  values to the call's arguments.
  """
  assert call_primitive.multiple_results
  params = dict(params, name=wrap_name(params.get('name', f.__name__), 'mask'))
  vals, shapes = unzip2((t.val, t.polymorphic_shape) for t in tracers)
  if not any(is_polymorphic(s) for s in shapes):
    return call_primitive.bind(f, *vals, **params)
  else:
    logical_env, padded_env = shape_envs
    env_keys, padded_env_vals = unzip2(sorted(padded_env.items()))
    logical_env_vals = tuple(logical_env[k] for k in env_keys)
    # Make padded_env hashable
    padded_env = (env_keys, padded_env_vals)
    f, shapes_out = mask_subtrace(f, self.main, shapes, padded_env)
    if 'donated_invars' in params:
      # The prepended env values are never donated.
      params = dict(params, donated_invars=((False,) * len(logical_env_vals) +
                                            params['donated_invars']))
    vals_out = call_primitive.bind(f, *(logical_env_vals + vals), **params)
    return [MaskTracer(self, v, s) for v, s in zip(vals_out, shapes_out())]
def post_process_custom_jvp_call(self, out_tracers, params):
  """Unwrap batch tracers leaving a custom-JVP call; return a re-wrap thunk."""
  vals, dims = unzip2((t.val, t.batch_dim) for t in out_tracers)
  main = self.main

  def todo(xs):
    # Re-lift the raw values at the then-current sublevel, reusing the
    # recorded batch dims.
    trace = main.with_cur_sublevel()
    return [BatchTracer(trace, x, d) for x, d in zip(xs, dims)]

  return vals, todo
def batch_subtrace(main, in_dims, *in_vals):  # used in e.g. process_call
  """Run the wrapped computation under a batching subtrace.

  Inputs with dim `None` pass through unwrapped; others become
  BatchTracers. Yields (out_vals, out_dims).
  """
  trace = main.with_cur_sublevel()
  if callable(in_dims):
    in_dims = in_dims()
  src = source_info_util.current()
  tracers_in = [x if dim is None else BatchTracer(trace, x, dim, src)
                for x, dim in zip(in_vals, in_dims)]
  outs = yield tracers_in, {}
  tracers_out = [trace.full_raise(o) for o in outs]
  yield unzip2((t.val, t.batch_dim) for t in tracers_out)
def process_custom_jvp_call(self, prim, fun, jvp, tracers):
  """Batching rule for custom-JVP calls.

  Wraps both `fun` and `jvp` in batching subtraces; only one of the two
  auxiliary out-dims stores is populated depending on which path ran.
  """
  in_vals, in_dims = unzip2((t.val, t.batch_dim) for t in tracers)
  fun, out_dims1 = batch_subtrace(fun, self.main, in_dims)
  jvp, out_dims2 = batch_custom_jvp_subtrace(jvp, self.main, in_dims)
  out_vals = prim.bind(fun, jvp, *in_vals)
  fst, out_dims = lu.merge_linear_aux(out_dims1, out_dims2)
  if not fst:
    # The jvp path ran: its outputs are (primals, tangents) doubled up, and
    # both halves must carry identical batch dims; keep one half.
    assert out_dims == out_dims[:len(out_dims) // 2] * 2
    out_dims = out_dims[:len(out_dims) // 2]
  return [BatchTracer(self, v, d) for v, d in zip(out_vals, out_dims)]
def post_process_call(self, call_primitive, out_tracers, params):
  """Unwrap jet tracers leaving a call; return flat values plus a re-wrap thunk."""
  primals, series = unzip2((t.primal, t.terms) for t in out_tracers)
  out, treedef = tree_flatten((primals, series))
  # Avoid closing over the unflattened structures.
  del primals, series
  main = self.main

  def todo(flat):
    primals, series = tree_unflatten(treedef, flat)
    trace = JetTrace(main, core.cur_sublevel())
    return [JetTracer(trace, p, s) for p, s in zip(primals, series)]

  return out, todo
def mask_fun(fun, logical_env, padded_env, in_vals, polymorphic_shapes):
  """Apply `fun` under a MaskTrace, threading the shape environments through.

  Returns the output values and the thunked output polymorphic shapes.
  """
  env_keys, padded_env_vals = unzip2(sorted(padded_env.items()))
  logical_env_vals = [logical_env[key] for key in env_keys]
  # Make padded_env hashable
  padded_env = (env_keys, padded_env_vals)
  with core.new_main(MaskTrace) as main:
    fun, out_shapes = mask_subtrace(fun, main, polymorphic_shapes, padded_env)
    out_vals = fun.call_wrapped(*(logical_env_vals + in_vals))
    del main  # drop the reference before leaving the trace context
  return out_vals, out_shapes()
def process_call(self, call_primitive, f, tracers, params):
  """Jet rule for call primitives: trace `f` under a jet subtrace.

  Flattens (primals, series) into one argument list, rebinds the call, and
  unflattens the results back into JetTracers.
  """
  primals_in, series_in = unzip2((t.primal, t.terms) for t in tracers)
  primals_and_series, in_tree_def = tree_flatten((primals_in, series_in))
  f_jet, out_tree_def = traceable(jet_subtrace(f, self.main), in_tree_def)
  update_params = call_param_updaters.get(call_primitive)
  new_params = (update_params(params, len(primals_and_series))
                if update_params else params)
  result = call_primitive.bind(f_jet, *primals_and_series, **new_params)
  primals_out, series_out = tree_unflatten(out_tree_def(), result)
  return [JetTracer(self, p, ts) for p, ts in zip(primals_out, series_out)]
def process_map(self, map_primitive, f: lu.WrappedFun, tracers, params):
  """Batching rule for map primitives (e.g. pmap/xmap calls)."""
  vals, dims = unzip2((t.val, t.batch_dim) for t in tracers)
  if all(dim is not_mapped for dim in dims):
    return map_primitive.bind(f, *vals, **params)
  else:
    # All batched inputs must agree on the batch-axis size.
    assert len({x.shape[d] for x, d in zip(vals, dims) if d is not not_mapped}) == 1
    # The logic for the dimension math below is as follows:
    # ╔═════════════╦════════════════════════════════════════╦═══════════╗
    # ║ d / in_axis ║ None                                   ║ int       ║
    # ╠═════════════╬════════════════════════════════════════╩═══════════╣
    # ║ None        ║ No extra axis, so in_axis unaffected               ║
    # ╠═════════════╬════════════════════════════════════════╦═══════════╣
    # ║ int         ║ Not mapped, so batching dim unaffected ║ See below ║
    # ╚═════════════╩════════════════════════════════════════╩═══════════╝
    # When both d and in_axis are defined then:
    # - If `d <= in_axis`, we have to move the `in_axis` one dimension further;
    # - If `d > in_axis`, we have to decrement `d` (as `in_axis` will get removed).
    def both_mapped(in_out_axis, d):
      return in_out_axis is not None and d is not not_mapped
    new_in_axes = tuple(
        in_axis + 1 if both_mapped(in_axis, d) and d <= in_axis else in_axis
        for d, in_axis in zip(dims, params['in_axes']))
    new_dims = tuple(
        d - 1 if both_mapped(in_axis, d) and in_axis < d else d
        for d, in_axis in zip(dims, params['in_axes']))
    f, dims_out = batch_subtrace(f, self.main, new_dims)
    out_axes_thunk = params['out_axes_thunk']
    # NOTE: This assumes that the choice of the dimensions over which outputs
    # are batched is entirely dependent on the function and not e.g. on the
    # data or its shapes.
    @as_hashable_function(closure=out_axes_thunk)
    def new_out_axes_thunk():
      return tuple(
          out_axis + 1 if both_mapped(out_axis, d) and d < out_axis else out_axis
          for out_axis, d in zip(out_axes_thunk(), dims_out()))
    new_params = dict(params, in_axes=new_in_axes, out_axes_thunk=new_out_axes_thunk)
    vals_out = map_primitive.bind(f, *vals, **new_params)
    # Shift batch dims past the newly-inserted map output axes where needed.
    dims_out = (d + 1 if both_mapped(out_axis, d) and out_axis <= d else d
                for d, out_axis in zip(dims_out(), out_axes_thunk()))
    src = source_info_util.current()
    return [BatchTracer(self, v, d, src) for v, d in zip(vals_out, dims_out)]
def jvp_subtrace(main, primals, tangents):
  """Run the wrapped computation under a JVPTrace sublevel.

  Inputs with a symbolic-zero tangent pass through unwrapped; others become
  JVPTracers. Yields the output (primals, tangents) pair.
  """
  trace = JVPTrace(main, core.cur_sublevel())
  # Sanity check: no input tracer may come from a deeper trace.
  for val in (*primals, *tangents):
    if isinstance(val, Tracer):
      assert val._trace.level < trace.level
  tracers_in = [x if type(t) is Zero else JVPTracer(trace, x, t)
                for x, t in zip(primals, tangents)]
  outs = yield tracers_in, {}
  tracers_out = [trace.full_raise(o) for o in outs]
  yield unzip2((t.primal, t.tangent) for t in tracers_out)
def _squeeze_lowering(ctx, x, dimensions):
  """Lower `squeeze` to an index/For expression.

  Builds a loop over the output shape, indexing the input with `unitIdx`
  at each squeezed dimension and a fresh loop index elsewhere.
  """
  in_aval, = ctx.avals_in
  out_aval, = ctx.avals_out
  if not out_aval.shape:
    # Fully squeezed to a scalar: a single unit-index suffices.
    return Idx(x, (unitIdx,))
  idx_names, idx_tys = unzip2((ctx.fresh('i'), FinType(Literal(sz)))
                              for sz in out_aval.shape)
  idx_name = iter(idx_names)
  # Consume one fresh index per non-squeezed input dimension, in order.
  idxs = [unitIdx if dim in dimensions else Var(next(idx_name))
          for dim in range(in_aval.ndim)]
  return For(tuple(idx_names), tuple(idx_tys), Idx(x, tuple(idxs)))
def process_primitive(self, primitive, tracers, params):
  """Apply a primitive's registered JVP rule to the traced (primal, tangent) pairs.

  Raises NotImplementedError when no rule is registered for `primitive`.
  """
  primals_in, tangents_in = unzip2((t.primal, t.tangent) for t in tracers)
  jvp = primitive_jvps.get(primitive)
  if not jvp:
    raise NotImplementedError(
        f"Differentiation rule for '{primitive}' not implemented")
  primal_out, tangent_out = jvp(primals_in, tangents_in, **params)
  if not primitive.multiple_results:
    return JVPTracer(self, primal_out, tangent_out)
  return map(partial(JVPTracer, self), primal_out, tangent_out)
def with_mesh(named_shape: MeshSpec) -> Generator[None, None, None]:
  """Test utility for setting up meshes given mesh data from `schedules`."""
  # This is similar to the `with_mesh` function above, but isn't a decorator.
  axis_names, shape = unzip2(named_shape)
  size = prod(shape)
  available = list(jax.local_devices())
  if len(available) < size:
    raise SkipTest(f"Test requires {size} local devices")
  device_grid = np.array(available[:size]).reshape(shape)
  with mesh(device_grid, axis_names):
    yield
def process_custom_jvp_call(self, _, __, f_jvp, tracers):
  """JVP rule for custom-JVP calls: invoke the user-supplied `f_jvp`.

  The first half of `f_jvp`'s outputs are primals, the second half
  tangents.
  """
  primals_in, tangents_in = unzip2((t.primal, t.tangent) for t in tracers)
  primals_in = map(core.full_lower, primals_in)
  tangents_in = map(instantiate_zeros, tangents_in)
  # Cast float0 to zeros with the primal dtype because custom jvp rules don't
  # currently handle float0s
  tangents_in = map(replace_float0s, primals_in, tangents_in)
  outs = f_jvp.call_wrapped(*it.chain(primals_in, tangents_in))
  primals_out, tangents_out = split_list(outs, [len(outs) // 2])
  # Restore float0 tangents where the primal dtype requires them.
  tangents_out = map(recast_to_float0, primals_out, tangents_out)
  return map(partial(JVPTracer, self), primals_out, tangents_out)
def _broadcasting_binop(binop_expr: Expr, ctx, x, y):
  """Lower a broadcasting binary op to a For expression applying `binop_expr`.

  Scalars are applied directly; otherwise both operands are wrapped in
  broadcast-aware indexing expressions over the output shape.
  """
  x_aval, y_aval = ctx.avals_in
  out_aval, = ctx.avals_out
  if not out_aval.shape:
    # Scalar output: curried application, no loop needed.
    return App(App(binop_expr, x), y)
  idx_names, idx_tys = unzip2((ctx.fresh('i'), FinType(Literal(sz)))
                              for sz in out_aval.shape)
  x_expr = _make_bcast_expr(idx_names, out_aval.shape, x_aval.shape, x)
  y_expr = _make_bcast_expr(idx_names, out_aval.shape, y_aval.shape, y)
  out = For(tuple(idx_names), tuple(idx_tys),
            App(App(binop_expr, x_expr), y_expr))
  return out
def process_custom_vjp_call(self, _, __, fwd, bwd, tracers, *, out_trees):
  """JVP rule for custom-VJP calls.

  Runs the user's forward rule to get residuals and primal outputs, then
  stages a `custom_lin_p` equation carrying `bwd` so transposition can
  later apply the user's backward rule to the tangents.
  """
  primals_in, tangents_in = unzip2((t.primal, t.tangent) for t in tracers)
  tangents_in = map(instantiate_zeros, tangents_in)
  res_and_primals_out = fwd.call_wrapped(*map(core.full_lower, primals_in))
  out_tree, res_tree = out_trees()
  res, primals_out = split_list(res_and_primals_out, [res_tree.num_leaves])
  avals_out = [raise_to_shaped(core.get_aval(x)) for x in primals_out]
  tangents_out = custom_lin_p.bind(
      *res, *tangents_in, num_res=res_tree.num_leaves, bwd=bwd,
      out_avals=avals_out)
  # Restore float0 tangents where the primal dtype requires them.
  tangents_out = map(recast_to_float0, primals_out, tangents_out)
  return map(partial(JVPTracer, self), primals_out, tangents_out)
def jvp_subtrace_aux(main, primals, tangents):
  """Like `jvp_subtrace`, but the wrapped function also returns auxiliary data.

  Auxiliary outputs are lowered to their primal values (not differentiated).
  """
  trace = JVPTrace(main, core.cur_sublevel())
  # Sanity check: no input tracer may come from a deeper trace.
  for x in list(primals) + list(tangents):
    if isinstance(x, Tracer):
      assert x._trace.level < trace.level
  ans, aux = yield map(partial(JVPTracer, trace), primals, tangents), {}
  ans_tracers = map(trace.full_raise, ans)
  out_primals, out_tangents = unzip2((t.primal, t.tangent) for t in ans_tracers)
  # Strip this trace's tracers from aux, keeping only primal values.
  aux_primals = [core.full_lower(x.primal)
                 if isinstance(x, JVPTracer) and x._trace.level == trace.level
                 else x for x in aux]
  yield (out_primals, out_tangents), aux_primals