def unzip_eval_wrapper(pvs, *consts):
  """Function transformation that returns init/apply jaxprs and metadata.

  Feeds the wrapped function a single argument: the list of PartialVals built
  by pairing each entry of `pvs` with the matching entry of `consts`. The
  wrapped function's result is then flattened into a tuple of constant outputs
  plus an auxiliary metadata tuple whose first two entries are the `success`
  flag and the number of flat outputs.

  On success the result carries separate init and apply traces. Otherwise it
  carries a single jaxpr with its out pvals, keys, consts and env.
  """
  partial_vals = safe_map(pe.PartialVal, safe_zip(pvs, consts))
  success, result = yield (partial_vals,), {}
  if not success:
    jaxpr, (out_pvals, out_keys, residual_consts, env) = result
    out_pvs, out_consts = jax_util.unzip2(out_pvals)
    flat_out = (*out_consts, *residual_consts)
    aux = (out_pvs, out_keys, jaxpr, env)
    yield flat_out, (success, len(flat_out), aux)
  else:
    (init_jaxpr, init_consts, init_env), (apply_jaxpr, apply_consts,
                                          apply_env), pvals, metadata = result
    init_pvals, apply_pvals = pvals
    init_pvs, init_pv_consts = jax_util.unzip2(init_pvals)
    apply_pvs, apply_pv_consts = jax_util.unzip2(apply_pvals)
    # Flat output order: init pval consts, init consts, apply pval consts,
    # apply consts.
    flat_out = (*init_pv_consts, *init_consts, *apply_pv_consts, *apply_consts)
    aux = ((init_pvs, len(init_consts), apply_pvs),
           (init_jaxpr, apply_jaxpr),
           (init_env, apply_env),
           metadata)
    yield flat_out, (success, len(flat_out), aux)
def trace_to_jaxpr_finalize(in_tracers, out_tracers, trace, instantiate=True):
  """Turn traced inputs/outputs into a closed-over TypedJaxpr.

  Args:
    in_tracers: tracers for the traced function's inputs.
    out_tracers: tracers for the traced function's outputs.
    trace: the trace the output tracers are raised into.
    instantiate: either a single flag applied to every output, or a sequence
      with one flag per output tracer, controlling whether known outputs are
      instantiated as jaxpr values.

  Returns:
    A `(typed_jaxpr, consts)` pair. The jaxpr's constvars are converted to
    leading invars (`pe.convert_constvars_jaxpr`), so its input avals are the
    const avals followed by the input avals.
  """
  # TODO: This is the final part of the partial_eval.trace_to_subjaxpr. Share.
  # Accept a per-output sequence as well as a scalar flag. Replicating a
  # sequence would produce a list of lists, whose truthiness ignores the
  # per-output values.
  if not isinstance(instantiate, (list, tuple)):
    instantiate = [instantiate] * len(out_tracers)
  out_tracers = safe_map(trace.full_raise,
                         safe_map(core.full_lower, out_tracers))
  out_tracers = safe_map(partial(pe.instantiate_const_at, trace), instantiate,
                         out_tracers)
  jaxpr, consts, env = pe.tracers_to_jaxpr(in_tracers, out_tracers)
  out_pvals = [t.pval for t in out_tracers]
  # TODO: this is from partial_eval.trace_to_jaxpr. Share.
  assert not env
  # TODO: this is from the final part of lax_control_flow._initial_style_jaxpr
  out_avals = safe_map(abstract_arrays.raise_to_shaped, unzip2(out_pvals)[0])
  const_avals = tuple(
      abstract_arrays.raise_to_shaped(core.get_aval(c)) for c in consts)
  in_pvals = [t.pval for t in in_tracers]
  in_avals = tuple(
      safe_map(abstract_arrays.raise_to_shaped, unzip2(in_pvals)[0]))
  typed_jaxpr = core.TypedJaxpr(pe.convert_constvars_jaxpr(jaxpr), (),
                                const_avals + in_avals, out_avals)
  return typed_jaxpr, consts
def dot_general_dependency_rule(outstart, outcount, lhs, rhs,
                                dimension_numbers, precision):
  """Dependency rule for `lax.dot_general`.

  Given an output box (start indices `outstart`, counts `outcount`), returns
  the boxes of `lhs` and `rhs` that the output box depends on, the per-input
  counts, and a function that evaluates `lax.dot_general` on the input slices.

  Returns:
    `([lhs_box, rhs_box], incounts, fn)`. Each entry of `incounts` is None
    when the corresponding operand is not a `LazyArray`.

  Raises:
    NotImplementedError: unless `is_ones(outcount)` holds.
  """
  # Only the `is_ones(outcount)` case is supported.
  if not is_ones(outcount):
    raise NotImplementedError
  outshape = outcount.shape
  # One (start, size) pair per output dimension.
  outslices = list(zip(outstart, outshape))
  (lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch) = dimension_numbers
  # dot_general output layout is: batch dims, then lhs free dims, then rhs free
  # dims. These ranges give the output-dimension indices of each free group.
  lhs_other_out_dims = list(
      range(len(lhs_batch), len(lhs.shape) - len(lhs_contracting)))
  rhs_other_out_dims = list(
      range(len(rhs_batch) + len(lhs_other_out_dims), len(outshape)))
  # NOTE(review): `lhs_batch` / `rhs_batch` hold operand-dimension indices but
  # are used here as output-dimension indices. That is only valid when the
  # batch dims are the leading dims of each operand. Confirm callers guarantee it.
  lhs_outstart, lhs_outshape = unzip2(
      [outslices[d] for d in list(lhs_batch) + lhs_other_out_dims])
  # The lhs dependency is a reduction of lhs over its contracting axes.
  (lhs_box, ), (lhs_count, ), _ = reduce_dependency_rule(None)(
      lhs_outstart, Ones(lhs_outshape), lhs, axes=lhs_contracting)
  rhs_outstart, rhs_outshape = unzip2(
      [outslices[d] for d in list(rhs_batch) + rhs_other_out_dims])
  (rhs_box, ), (rhs_count, ), _ = reduce_dependency_rule(None)(
      rhs_outstart, Ones(rhs_outshape), rhs, axes=rhs_contracting)
  # Each lhs element is used once per output position along the rhs free dims
  # (and vice versa), so scale the reduction counts by those extents.
  incounts = [
      materialize(lhs_count) * prod(np.take(outshape, rhs_other_out_dims))
      if isinstance(lhs, LazyArray) else None,
      materialize(rhs_count) * prod(np.take(outshape, lhs_other_out_dims))
      if isinstance(rhs, LazyArray) else None
  ]
  return ([lhs_box, rhs_box], incounts, lambda *inslices: lax.dot_general(
      *inslices, dimension_numbers, precision))
def split_tracers_and_nontracers(jaxpr, consts):
  """Partition a jaxpr's constvars by whether their value can become a literal.

  A constant counts as non-literal when it is a `core.Tracer` or an
  `xla.DeviceArray`.

  Returns:
    `(reordered_constvars, literal_consts, nonliteral_consts)`. The literal
    constvars come first in `reordered_constvars`, then the non-literal ones.
    Relative order is preserved within each group.
  """
  # TODO(phawkins): We avoid treating DeviceArrays as constant literals so
  # we don't copy large arrays back to the host. We probably should relax
  # this and either always copy small constants, or opportunistically use
  # DeviceArray values for which we already know npy_value.
  def _keep_as_operand(pair):
    return isinstance(pair[1], (core.Tracer, xla.DeviceArray))

  pairs = list(zip(jaxpr.constvars, consts))
  operand_pairs = [p for p in pairs if _keep_as_operand(p)]
  literal_pairs = [p for p in pairs if not _keep_as_operand(p)]
  operand_vars, operand_consts = unzip2(operand_pairs)
  literal_vars, literal_consts = unzip2(literal_pairs)
  return literal_vars + operand_vars, literal_consts, operand_consts
def process_call(self, call_primitive, f, tracers, params):
  """Jet-transform a call primitive by tracing its body with `jet_subtrace`.

  Primals and series are flattened together into one operand list. Because
  the operand count changes, any `donated_invars` parameter is rewritten to
  match it.

  Raises:
    ValueError: if the call requests buffer donation, which jet does not
      support.
  """
  primals_in, series_in = unzip2((t.primal, t.terms) for t in tracers)
  primals_and_series, in_tree_def = tree_flatten((primals_in, series_in))
  f_jet, out_tree_def = traceable(jet_subtrace(f, self.master), in_tree_def)
  new_params = dict(params)
  if "donated_invars" in params:
    # The original donated_invars has one entry per primal, but we now bind
    # primals *and* series, so its length would be wrong.
    if any(params["donated_invars"]):
      raise ValueError("Buffer donation is not supported with jet.")
    new_params["donated_invars"] = (False,) * len(primals_and_series)
  result = call_primitive.bind(f_jet, *primals_and_series, **new_params)
  # out_tree_def is only populated once the call above has traced f_jet.
  primals_out, series_out = tree_unflatten(out_tree_def(), result)
  return [JetTracer(self, p, ts) for p, ts in zip(primals_out, series_out)]
def wrapped(*args, **kwargs):
  """Invert the closed-over function `f` at the given output values.

  `f` is staged on `trace_args` when that is non-empty, otherwise on `args`.
  The staged jaxpr's inputs start as unknown cells and its outputs are seeded
  from `args`. `propagate.propagate` then fills in the inputs using
  `ildj_registry`.

  Returns:
    `(vals, ildj)`: the recovered inputs and their ILDJs. The ILDJs are summed
    into one value when `reduce_ildj` is set. Both are taken out of their
    1-tuple when there is exactly one forward argument.

  Raises:
    ValueError: if any input cell is not fully determined (`top()` is false).
  """
  forward_args = trace_args if len(trace_args) else args
  staged, (in_tree, _) = trace_util.stage(f, dynamic=False)(
      *forward_args, **kwargs)
  flat_forward_args, _ = tree_util.tree_flatten(forward_args)
  flat_args, _ = tree_util.tree_flatten(args)
  const_cells = safe_map(InverseAndILDJ.new, staged.literals)
  forward_avals = [trace_util.get_shaped_aval(a) for a in flat_forward_args]
  # Inputs start unknown; outputs are seeded with the values to invert.
  in_cells = [InverseAndILDJ.unknown(aval) for aval in forward_avals]
  out_cells = safe_map(InverseAndILDJ.new, flat_args)
  env = propagate.propagate(InverseAndILDJ, ildj_registry, staged.jaxpr,
                            const_cells, in_cells, out_cells)
  solved = [env.read(v) for v in staged.jaxpr.invars]
  if not all(cell.top() for cell in solved):
    raise ValueError('Cannot invert function.')
  flat_vals, flat_ildjs = jax_util.unzip2(
      [(cell.val, cell.ildj) for cell in solved])
  vals = tree_util.tree_unflatten(in_tree, flat_vals)
  if reduce_ildj:
    ildj_ = sum(np.sum(x) for x in flat_ildjs)
  else:
    ildj_ = tree_util.tree_unflatten(in_tree, flat_ildjs)
  if len(forward_args) == 1:
    vals = vals[0]
    if not reduce_ildj:
      ildj_ = ildj_[0]
  return vals, ildj_
def testScanVmapTuples(self):
  """vmap of lax.scan over tuple carries/inputs matches a per-example loop."""

  def step(carry, x):
    a1, a2 = x
    c1, c2 = carry
    b = np.sum(np.cos(a1)) * np.sum(np.tan(c2 * a2))
    new_carry = c1 * np.sin(np.sum(a1 * a2)), c2 * np.cos(np.sum(a1))
    return new_carry, b

  in_axes = (0, (1, 2))
  rng = onp.random.RandomState(0)
  as_ = (rng.randn(3, 7), rng.randn(3, 4, 7))
  c = (rng.randn(7, 2), rng.randn(7))

  # Reference result: run scan separately for each of the 7 mapped examples.
  per_example = [
      lax.scan(step, (c[0][i], c[1][i]), (as_[0][:, i], as_[1][:, :, i]))
      for i in range(7)
  ]
  carries, ys = unzip2(per_example)
  carry0, carry1 = unzip2(carries)
  expected = (np.stack(carry0), np.stack(carry1)), np.stack(ys)

  ans = api.vmap(lambda c, as_: lax.scan(step, c, as_), in_axes)(c, as_)
  self.assertAllClose(ans, expected, check_dtypes=False)
def custom_layer_cau_batch(trace, f, tracers, params):
  """Batching rule for layer_cau primitive to handle custom layers.

  If no input is batched, binds `layer_cau_p` directly. If the layer object
  (first unflattened argument) has `_call_and_update_batched`, and its leaves
  are all unmapped while the remaining args are all mapped, the batched path
  `_layer_cau_batched` is called directly. Otherwise the rule falls back to
  `batching.batch_subtrace` and binds the 'batch' subcall.

  NOTE(review): the nesting below was reconstructed from a whitespace-collapsed
  source. The `else:` branches are read as: inner `rng_dim` check, then the
  `has_rng` check. Confirm against the canonical file.
  """
  vals, dims = jax_util.unzip2((t.val, t.batch_dim) for t in tracers)
  # Nothing batched: bind the primitive on the raw values.
  if all(dim is batching.not_mapped for dim in dims):
    return layer_cau_p.bind(f, *vals, **params)
  args = tree_util.tree_unflatten(params['in_tree'], vals)
  # NOTE(review): this maps None -> not_mapped. If not_mapped is None (as
  # batching.not_mapped is in JAX) this is a no-op; confirm intent.
  dims_ = [not_mapped if dim is None else dim for dim in dims]
  layer, args = args[0], args[1:]
  if hasattr(layer, '_call_and_update_batched'):
    # The first num_params flat dims belong to the layer's leaves, the rest to
    # the call arguments (optionally led by an rng).
    num_params = len(tree_util.tree_leaves(layer))
    layer_dims, arg_dims = dims_[:num_params], dims_[num_params:]
    if params['kwargs']['has_rng']:
      rng, args = args[0], args[1:]
      rng_dim, arg_dims = arg_dims[0], arg_dims[1:]
    mapping_over_layer = all(layer_dim is not not_mapped for layer_dim in layer_dims)
    mapping_over_args = all(arg_dim is not not_mapped for arg_dim in arg_dims)
    assert mapping_over_layer or mapping_over_args, (layer_dims, arg_dims)
    # Fast path: unbatched layer, fully batched args.
    if not mapping_over_layer and mapping_over_args:
      if params['kwargs']['has_rng']:
        if rng_dim is not not_mapped:
          # A batched rng is handled with an outer vmap over the rng only; the
          # layer and args are passed through unmapped (in_axes None).
          arg_dims = tuple(None if dim is not_mapped else dim for dim in arg_dims)
          map_fun = jax.vmap(
              lambda layer, rng, *args: _layer_cau_batched(
                  layer, rng, *args,  # pylint: disable=unnecessary-lambda, g-long-lambda
                  **params['kwargs']),
              in_axes=(None, rng_dim) + (None, ) * len(arg_dims))
        else:
          # The rng is passed through as the first positional of *args.
          map_fun = lambda layer, *args: _layer_cau_batched(
              layer, *args,  # pylint: disable=unnecessary-lambda, g-long-lambda
              **params['kwargs'])
        vals_out, update_out = map_fun(layer, rng, *args)
      else:
        vals_out, update_out = _layer_cau_batched(
            layer, *args, **params['kwargs'])
      vals_out = tree_util.tree_leaves(vals_out)
      update_out = tree_util.tree_leaves(update_out)
      assert all(dim == 0 for dim in arg_dims)
      # Assume dimensions out are consistent
      # Outputs are batched along axis 0; updates are unbatched.
      dims_out = (0, ) * len(vals_out)
      dims_update = (None, ) * len(update_out)
      # NOTE(review): dims_update is appended here AND again in the zip below.
      # The result is only correct because zip truncates to
      # len(vals_out + update_out).
      dims_out = dims_out + dims_update
      # Call wrapped function to avoid linear_util error
      f.call_wrapped(*tracers)
      return [
          batching.BatchTracer(trace, v, d)
          for v, d in zip(vals_out + update_out, dims_out + dims_update)
      ]
  # General path: batch the wrapped function and bind the 'batch' subcall.
  f, dims_out = batching.batch_subtrace(f, trace.master, dims)
  vals_out = layer_cau_p.subcall('batch').bind(f, *vals, **params)
  return [
      batching.BatchTracer(trace, v, d)
      for v, d in zip(vals_out, dims_out())
  ]
def jet_subtrace(main, primals, series):
  """Transformation that runs a function under a JetTrace at a new sublevel.

  Wraps each (primal, series) input pair in a JetTracer, then raises the
  function's outputs into the trace and yields their primals and terms.
  """
  sub_trace = JetTrace(main, core.cur_sublevel())
  ans = yield map(partial(JetTracer, sub_trace), primals, series), {}
  raised = map(sub_trace.full_raise, ans)
  yield unzip2((tracer.primal, tracer.terms) for tracer in raised)
def _make_typed_jaxpr(traceable, in_avals):
  """Trace `traceable` on abstract (unknown) inputs of `in_avals`.

  Returns:
    A `core.TypedJaxpr` holding the traced jaxpr, its consts, `in_avals` and
    the avals of the (instantiated) outputs.
  """
  abstract_inputs = [pe.PartialVal((aval, core.unit)) for aval in in_avals]
  jaxpr, pvals_out, consts = pe.trace_to_jaxpr(
      traceable, abstract_inputs, instantiate=True)
  out_avals = unzip2(pvals_out)[0]
  return core.TypedJaxpr(jaxpr, consts, in_avals, out_avals)
def jet_transform(primals, series):
  """Transformation that runs a function under a fresh JetTrace master.

  Both yields happen inside the `new_master` context, so the master is
  alive while the function runs and while the outputs are collected.
  """
  with core.new_master(JetTrace) as master:
    top_trace = JetTrace(master, core.cur_sublevel())
    ans = yield map(partial(JetTracer, top_trace), primals, series), {}
    raised = map(top_trace.full_raise, ans)
    out_primals, out_terms = unzip2((t.primal, t.terms) for t in raised)
    yield out_primals, out_terms
def doubling_subtrace(main, heads, tails):
  """Transformation that runs a function under a DoublingTrace sublevel.

  Inputs whose tail is None are passed through as bare heads. The outputs are
  raised into the trace and split into `(heads, tails)`.
  """
  trace = DoublingTrace(main, core.cur_sublevel())

  def _as_input(head, tail):
    return head if tail is None else DoublingTracer(trace, head, tail)

  ans = yield [_as_input(h, t) for h, t in zip(heads, tails)], {}
  raised = map(trace.full_raise, ans)
  yield unzip2([(r.head, r.tail) for r in raised])
def _initial_style_jaxpr(fun, in_tree, in_avals):
  """Trace `fun` into a closure-converted TypedJaxpr.

  Returns:
    `(typed_jaxpr, consts, out_tree)`. The jaxpr takes the consts' avals
    followed by `in_avals` as inputs.
  """
  abstract_inputs = [pe.PartialVal((aval, core.unit)) for aval in in_avals]
  flat_fun, out_tree = flatten_fun_nokwargs(lu.wrap_init(fun), in_tree)
  jaxpr, out_pvals, consts = pe.trace_to_jaxpr(
      flat_fun, abstract_inputs, instantiate=True)
  out_pvs, _ = unzip2(out_pvals)
  out_avals = _map(raise_to_shaped, out_pvs)
  const_avals = tuple(raise_to_shaped(core.get_aval(c)) for c in consts)
  closed = core.TypedJaxpr(pe.closure_convert_jaxpr(jaxpr), (),
                           const_avals + in_avals, out_avals)
  return closed, consts, out_tree()
def new_f(*args, **kwargs):
  """Run `f` inside a device mesh shaped by `named_shape`.

  Raises:
    SkipTest: when fewer local devices are available than the mesh needs.
  """
  axis_names, shape = unzip2(named_shape)
  # np.prod returns a float (1.0) for an empty shape. Cast it so the value can
  # be used as a list slice bound below.
  size = int(np.prod(shape))
  local_devices = list(jax.local_devices())
  if len(local_devices) < size:
    raise SkipTest(f"Test requires {size} local devices")
  mesh_devices = np.array(local_devices[:size]).reshape(shape)
  with mesh(mesh_devices, axis_names):
    return f(*args, **kwargs)
def _scan_partial_eval(trace, *tracers, **kwargs):
  """Partial-evaluation rule for `scan_p`.

  Splits the scan body `jaxpr` into a known part (`jaxpr_1`), bound
  immediately on the known inputs, and an unknown part (`jaxpr_2`), recorded
  as a `JaxprEqn` on the result tracer. The split is chosen by iterating until
  the carry's unknown-ness reaches a fixed point.

  Expects exactly three tracers (consts, init carry, xs) and the keyword
  params `jaxpr`, `length` and `forward`.

  Raises:
    FixedPointError: if the carry does not stabilise within 1000 iterations.
  """
  jaxpr = kwargs.pop('jaxpr')
  length = kwargs.pop('length')
  forward = kwargs.pop('forward')
  assert not kwargs
  in_pvs, _ = unzip2([t.pval for t in tracers])
  # presumably pe.unknown maps each pv to a structure of "is unknown" flags —
  # TODO confirm.
  sc_consts, sc_init, sc_xs = map(pe.unknown, in_pvs)

  # Fixed-point iteration: if the carry coming out of the body is "more
  # unknown" than the one going in, join the two and retry.
  sc_carry = sc_init
  for i in range(1000):
    second_components = (sc_consts, sc_carry, sc_xs)
    jaxpr_1, jaxpr_2, sc_out = pe.partial_eval_jaxpr(jaxpr, second_components,
                                                     instantiate=(sc_carry, False))
    sc_carry_out, sc_ys = sc_out
    if sc_carry_out == sc_carry:
      break
    else:
      sc_carry = _binary_lattice_join(sc_carry, sc_carry_out)
  else:
    raise FixedPointError

  consts_tracer, init_tracer, xs_tracer = tracers
  # presumably makes the init carry unknown wherever sc_carry is unknown, so it
  # matches the fixed point — TODO confirm.
  lifted_init_tracer = _lift_tracer(trace, init_tracer, sc_carry)
  lifted_tracers = consts_tracer, lifted_init_tracer, xs_tracer
  in_pvs, in_consts = unzip2([t.pval for t in lifted_tracers])

  # The output is (carry, ys), where ys is y's aval promoted by `length`.
  carry_aval, y_aval = jaxpr.out_aval
  ys_aval = _promote_aval_rank(length, y_aval)
  out_aval = core.AbstractTuple((carry_aval, ys_aval))
  out_pv = _put_known_pvs(sc_out, out_aval)

  # Run the known half now. It also emits residuals for the unknown half.
  out_carry, (ys, residuals) = scan_p.bind(*in_consts, forward=forward,
                                           length=length, jaxpr=jaxpr_1)
  out_const = core.pack((out_carry, ys))
  residuals_tracer = trace.new_instantiated_const(core.pack(residuals))
  d, c, a = lifted_tracers
  # Residuals are packed alongside xs as the unknown half's scanned input.
  new_tracers = (d, c, (a, residuals_tracer))
  # NOTE(review): positional JaxprEqn fields (invars, outvars, primitive,
  # bound_subjaxprs, restructure, destructure, params) are assumed from this
  # JAX version; verify against core.JaxprEqn.
  eqn = core.JaxprEqn(new_tracers, None, scan_p, (), True, False,
                      dict(forward=forward, length=length, jaxpr=jaxpr_2))
  return pe.JaxprTracer(trace, pe.PartialVal((out_pv, out_const)), eqn)
def post_process_call(self, call_primitive, out_tracers, params):
  """Flatten jet outputs escaping a call and return a re-wrapping thunk.

  Returns:
    `(flat_out, todo)`. `flat_out` holds the flattened primals and series.
    `todo(x)` unflattens `x` and wraps the pairs in JetTracers of a new JetTrace
    at the current sublevel.
  """
  flat_out, treedef = tree_flatten(
      unzip2((t.primal, t.terms) for t in out_tracers))
  main_ref = self.master

  def todo(x):
    new_primals, new_series = tree_unflatten(treedef, x)
    sub = JetTrace(main_ref, core.cur_sublevel())
    return map(partial(JetTracer, sub), new_primals, new_series)

  return flat_out, todo
def process_call(self, call_primitive, f, tracers, params):
  """Process a call under the apply trace.

  If any input carries net params, they are merged and threaded into the
  callee via `apply_subtrace`, and the result is wrapped in an ApplyTracer.
  Otherwise the call is bound directly.

  Raises:
    NotImplementedError: for primitives in `pe.map_primitives`.
  """
  if call_primitive in pe.map_primitives:
    raise NotImplementedError
  vals, net_params = unzip2((t.val, t.net_params) for t in tracers)
  if not any(net_params):
    return call_primitive.bind(f, *vals, **params)
  merged = merge_params(net_params)
  f = apply_subtrace(f, self.master, WrapHashably(merged))
  return ApplyTracer(self, merged, call_primitive.bind(f, *vals, **params))
def process_primitive(self, primitive, tracers, params):
  """Apply a primitive under the apply trace.

  For `Layer` primitives, the layer's `apply_fun` is called with its entry of
  the merged net params. Other primitives are bound normally. The result is
  wrapped in an ApplyTracer that carries the merged params.
  """
  vals_in, net_params = unzip2((t.val, t.net_params) for t in tracers)
  merged = merge_params(net_params)
  if not isinstance(primitive, Layer):
    return ApplyTracer(self, merged, primitive.bind(*vals_in, **params))
  out = primitive.apply_fun(merged[primitive.name], *vals_in)
  return ApplyTracer(self, merged, out)
def tree_update(i, grad_tree, opt_state):
  """Apply one optimizer step to every parameter's state.

  Args:
    i: step index, passed as the first argument of `update`.
    grad_tree: gradients, flattened with `flatten`.
    opt_state: `(states_flat, tree, subtrees)` as stored in OptimizerState.

  Returns:
    A new OptimizerState. Tree/subtree structures are not validated here.
  """
  states_flat, tree, subtrees = opt_state
  grad_flat = flatten(grad_tree)
  states = map(pack_sequence_as, subtrees, states_flat)
  # Pair the j-th entries of states[0] and states[1]. Use a distinct loop
  # variable: reusing `i` here overwrote the step index passed to `update`.
  paired_states = [(states[0][j], states[1][j]) for j in range(len(states[0]))]
  new_states = map(partial(update, i), grad_flat, paired_states)
  new_states_flat = unzip2(map(flatten, new_states))
  return OptimizerState(new_states_flat, tree, unzip2(new_states))
def process_primitive(self, primitive, tracers, params):
  """Apply the registered jet rule for `primitive` to the incoming tracers.

  `zero_series` inputs are expanded to explicit zero arrays of the common
  series order before the rule from `jet_rules` is called.

  Raises:
    ValueError: if the non-zero input series have different orders.
  """
  primals_in, series_in = unzip2((t.primal, t.terms) for t in tracers)
  orders = {len(terms) for terms in series_in if terms is not zero_series}
  if not orders:
    # Every input series is zero_series, so all higher-order output terms are
    # zero as well. Without this branch, `order, = orders` would fail on the
    # empty set.
    return JetTracer(self, primitive.bind(*primals_in, **params), zero_series)
  if len(orders) > 1:
    raise ValueError(
        "jet input series must all have the same order, got orders {}."
        .format(sorted(orders)))
  order, = orders
  series_in = [[zero_term] * order if s is zero_series else s
               for s in series_in]
  # TODO(mattjj): avoid always instantiating zeros
  series_in = [[onp.zeros(onp.shape(x), dtype=onp.result_type(x))
                if t is zero_term else t for t in series]
               for x, series in zip(primals_in, series_in)]
  rule = jet_rules[primitive]
  primal_out, terms_out = rule(primals_in, series_in, **params)
  return JetTracer(self, primal_out, terms_out)
def process_call(self, call_primitive, f, tracers, params):
  """Double a call primitive by tracing its body with `doubling_subtrace`.

  The call is bound on the heads followed by the flattened leaves of the
  tails. Its name is wrapped with 'doubledouble', and donation is disabled
  for every operand.
  """
  assert call_primitive.multiple_results
  heads, tails = unzip2((t.head, t.tail) for t in tracers)
  nonzero_tails, in_tree_def = tree_flatten(tails)
  f_double, out_tree_def = screen_nones(
      doubling_subtrace(f, self.main), len(heads), in_tree_def)
  call_name = params.get('name', f.__name__)
  operands = [*heads, *nonzero_tails]
  new_params = dict(params)
  new_params['name'] = wrap_name(call_name, 'doubledouble')
  new_params['donated_invars'] = (False,) * len(operands)
  result = call_primitive.bind(f_double, *operands, **new_params)
  # out_tree_def is only populated after the bind above has traced f_double.
  heads_out, tails_out = tree_unflatten(out_tree_def(), result)
  return [DoublingTracer(self, h, t) for h, t in zip(heads_out, tails_out)]
def _ppermute_translation_rule(c, x, replica_groups, perm, platform=None):
  """Lower ppermute to an XLA CollectivePermute over all replica groups.

  Source/destination indices in `perm` are taken modulo the group size. Each
  normalized pair is then mapped to global replica ids within every group.

  Raises:
    ValueError: if, after normalization, sources or destinations repeat.
  """
  group_size = len(replica_groups[0])
  srcs, dsts = unzip2((src % group_size, dst % group_size) for src, dst in perm)
  if not (len(srcs) == len(set(srcs)) and len(dsts) == len(set(dsts))):
    msg = "ppermute sources and destinations must be unique, got {}."
    raise ValueError(msg.format(perm))
  # Index with the normalized pairs that were validated above. The raw
  # values can be out of range (IndexError) even though their residues are
  # valid.
  normalized_perm = list(zip(srcs, dsts))
  full_perm = []
  for grp in replica_groups:
    grp = sorted(grp)
    full_perm.extend((grp[src], grp[dst]) for src, dst in normalized_perm)
  return xops.CollectivePermute(x, full_perm)
def process_primitive(self, primitive, tracers, params):
  """Apply the registered jet rule for a single-result `primitive`.

  Zero series are expanded to `order` explicit zero arrays, where `order` is
  the trace master's series order, before `jet_rules[primitive]` is called.
  """
  assert not primitive.multiple_results  # TODO
  order = self.master.order  # pytype: disable=attribute-error
  primals_in, series_in = unzip2((t.primal, t.terms) for t in tracers)

  # TODO(mattjj): avoid always instantiating zeros
  def _densify(x, terms):
    if terms is zero_series:
      terms = [zero_term] * order
    return [np.zeros(np.shape(x), dtype=np.result_type(x))
            if t is zero_term else t for t in terms]

  dense_series = [_densify(x, s) for x, s in zip(primals_in, series_in)]
  primal_out, terms_out = jet_rules[primitive](primals_in, dense_series,
                                               **params)
  return JetTracer(self, primal_out, terms_out)
def process_call(self, call_primitive, f, tracers, params):
  """Jet-transform a call primitive by tracing its body with `jet_subtrace`.

  Raises:
    ValueError: if the call requests buffer donation.
  """
  primals_in, series_in = unzip2((t.primal, t.terms) for t in tracers)
  flat_inputs, in_tree_def = tree_flatten((primals_in, series_in))
  f_jet, out_tree_def = traceable(jet_subtrace(f, self.master), in_tree_def)
  new_params = dict(params)
  # The operand list now holds primals and series, so rewrite any
  # donated_invars to the new length.
  if "donated_invars" in new_params:
    if any(new_params["donated_invars"]):
      raise ValueError("Buffer donation is not supported with jet.")
    new_params["donated_invars"] = (False,) * len(flat_inputs)
  result = call_primitive.bind(f_jet, *flat_inputs, **new_params)
  # out_tree_def is only available after the bind above has traced f_jet.
  out_primals, out_series = tree_unflatten(out_tree_def(), result)
  return [JetTracer(self, p, ts) for p, ts in zip(out_primals, out_series)]
def process_call(self, call_primitive, f, tracers, params):
  """Jet-transform a call primitive by tracing its body with `jet_subtrace`.

  If `call_param_updaters` has an entry for the primitive, it rewrites the
  params for the new operand count (primals plus series).
  """
  primals_in, series_in = unzip2((t.primal, t.terms) for t in tracers)
  flat_inputs, in_tree_def = tree_flatten((primals_in, series_in))
  f_jet, out_tree_def = traceable(jet_subtrace(f, self.main), in_tree_def)
  updater = call_param_updaters.get(call_primitive)
  if updater:
    new_params = updater(params, len(flat_inputs))
  else:
    new_params = params
  result = call_primitive.bind(f_jet, *flat_inputs, **new_params)
  # out_tree_def is only available after the bind above has traced f_jet.
  out_primals, out_series = tree_unflatten(out_tree_def(), result)
  return [JetTracer(self, p, ts) for p, ts in zip(out_primals, out_series)]
def _ppermute_translation_rule(c, x, device_groups, perm):
  """Lower ppermute to `c.CollectivePermute` across every device group.

  Raises:
    ValueError: if any index in `perm` lies outside [0, group size), or if
      sources or destinations repeat.
  """
  group_size = len(device_groups[0])

  def _in_group(idx):
    return 0 <= idx < group_size

  if not all(_in_group(i) and _in_group(j) for i, j in perm):
    msg = (
        "ppermute permutation elements must take on values between 0 and "
        "the group size {}, but got {}.")
    raise ValueError(msg.format(group_size, perm))
  sources, dests = unzip2(perm)
  if len(sources) != len(set(sources)) or len(dests) != len(set(dests)):
    msg = "ppermute sources and destinations must be unique, got {}."
    raise ValueError(msg.format(perm))
  # Translate group-local indices into device ids, group by group.
  full_perm = []
  for group in device_groups:
    ordered = sorted(group)
    full_perm += [(ordered[s], ordered[d]) for s, d in perm]
  return c.CollectivePermute(x, full_perm)
def tree_update(i, grad_tree, opt_state):
  """Apply optimizer step `i` to a packed optimizer state.

  Raises:
    TypeError: if the gradient tree structure differs from the parameter
      tree, or if `update` returns a state whose structure differs from its
      input.
  """
  packed_state, tree, subtrees = opt_state
  grad_flat, grad_treedef = tree_flatten(grad_tree)
  if grad_treedef != tree:
    msg = ("optimizer update function was passed a gradient tree that did "
           "not match the parameter tree structure with which it was "
           "initialized: parameter tree {} and grad tree {}.")
    raise TypeError(msg.format(tree, grad_treedef))
  old_states = map(tree_unflatten, subtrees, packed_state)
  updated = map(partial(update, i), grad_flat, old_states)
  new_states_flat, new_subtrees = unzip2(map(tree_flatten, updated))
  for old_sub, new_sub in zip(subtrees, new_subtrees):
    if new_sub != old_sub:
      msg = ("optimizer update function produced an output structure that "
             "did not match its input structure: input {} and output {}.")
      raise TypeError(msg.format(old_sub, new_sub))
  return OptimizerState(pack(map(pack, new_states_flat)), tree, subtrees)
def pack_optimizer_state(marked_pytree):
  """Rebuild an OptimizerState from a pytree of JoinPoint markers.

  This is the inverse of `unpack_optimizer_state` and is intended for
  deserializing optimizer states. Each leaf of `marked_pytree` must be a
  JoinPoint whose `subtree` holds one parameter's state.

  Args:
    marked_pytree: A pytree whose leaves are JoinPoints wrapping subtrees.

  Returns:
    The equivalent OptimizerState.
  """
  joinpoints, tree_def = tree_flatten(marked_pytree)
  assert all(isinstance(jp, JoinPoint) for jp in joinpoints)
  states_flat, subtree_defs = unzip2(
      tree_flatten(jp.subtree) for jp in joinpoints)
  return OptimizerState(states_flat, tree_def, subtree_defs)
def _revise_cond_jaxpr(new_pval, old_pval, jaxpr, consts):
  """Adapt a cond branch jaxpr after its output pval was joined upward.

  When joining with the other branch makes more outputs unknown (the pv moves
  up the lattice), this branch's jaxpr must also output values for the
  positions that were previously known. Those values come from `old_const`
  and are appended as new consts.

  Returns:
    `(jaxpr, consts)`: unchanged when the pv did not change, otherwise a
    revised copy of the jaxpr and its extended consts.
  """
  new_pv, new_const = new_pval
  old_pv, old_const = old_pval
  if new_pv == old_pv:
    # we didn't move up the lattice by joining with the other side
    return jaxpr, consts
  elif old_pv is None:
    # we moved up the lattice from totally-known, so make a new jaxpr that
    # returns a single constant JaxTuple with elements that are constants
    # drawn from consts where new_pv is unknown
    assert not jaxpr.eqns and not consts
    # The new jaxpr just returns its single constvar.
    outvar = pe.Var(0, "_cond")
    new_jaxpr = jaxpr.copy()
    new_jaxpr.constvars = [outvar]
    new_jaxpr.outvar = outvar
    # Known positions (pv None) become units; unknown positions carry the
    # previously known value.
    new_consts = (core.pack([
        core.unit if pv is None else old_c
        for pv, old_c in zip(new_pv, old_const)
    ]), )
    return new_jaxpr, new_consts
  else:
    # we moved up the lattice, but not from totally-known, so adapt the
    # jaxpr to return some new constants in places that are now unknown but
    # weren't before
    # The final eqn must be the pack producing the jaxpr's output.
    eqn = jaxpr.eqns[-1]
    assert eqn.primitive == core.pack_p
    assert len(eqn.outvars) == 1 and eqn.outvars[0] == jaxpr.outvar
    newvar = pe.gensym("_cond")
    # One fresh constvar per position that went from known to unknown.
    new_constvars, new_constvals = unzip2([
        (newvar(), c)
        for new, old, c in zip(new_pv, old_pv, old_const)
        if old is None and new is not None
    ])
    new_consts = consts + tuple(new_constvals)
    new_jaxpr = jaxpr.copy()
    new_jaxpr.constvars = tuple(jaxpr.constvars) + tuple(new_constvars)
    # Rebuild the pack's inputs: newly unknown positions read the fresh
    # constvars in order; still-known positions become unitvar; all others keep
    # their original invar.
    newvars = iter(new_constvars)
    new_invars = [
        next(newvars) if old is None and new is not None else
        (core.unitvar if new is None and old is None else v)
        for new, old, v in zip(new_pv, old_pv, eqn.invars)
    ]
    new_jaxpr.eqns = (list(jaxpr.eqns[:-1]) +
                      [_pack_eqn(new_invars, jaxpr.outvar)])
    return new_jaxpr, new_consts
def default_process_primitive(self, primitive, tracers, params):
  """Partially evaluate a primitive and record variable recipes.

  Fully known inputs are bound directly. Otherwise the inputs are
  instantiated as UnzipTracers and the output avals come from
  `primitive.abstract_eval`. An eqn recipe is attached to the outputs. For
  `harvest.sow_p` with the active tag, applied to key tracers only, a
  VariableRecipe is also attached.
  """
  pvs, consts = jax_util.unzip2(t.pval for t in tracers)
  if all(pv is None for pv in pvs):
    return primitive.bind(*consts, **params)
  settings = trace_util.get_dynamic_context(self).settings
  tracers = safe_map(self.instantiate_const, tracers)
  assert all(isinstance(t, UnzipTracer) for t in tracers)
  key = all(t.is_key() for t in tracers)
  out_avals = primitive.abstract_eval(*[t.aval for t in tracers], **params)
  if not primitive.multiple_results:
    out_avals = [out_avals]
  out_tracers = [
      UnzipTracer(self, pe.PartialVal((aval, jax_core.unit)), None, key)
      for aval in out_avals
  ]
  # UnzipTracers are passed where pe expects JaxprTracers; pytype cannot see
  # that they are compatible.
  eqn = pe.new_eqn_recipe(tracers, out_tracers, primitive, params,
                          source_info_util.current())  # pytype: disable=wrong-arg-types
  for out in out_tracers:
    out.recipe = eqn
  # Unlike pe.JaxprTrace, tag-matching sow_p calls on key tracers also
  # attach a VariableRecipe. It is later used to build the init/apply jaxprs.
  if key and primitive is harvest.sow_p and params['tag'] == settings.tag:
    name, var_in_tracers, var_out_tracers = unzip_registry[primitive](
        tracers, out_tracers, **params)
    recipe = VariableRecipe(name, var_in_tracers, var_out_tracers)
    for out in out_tracers:
      out.variable_recipe = recipe
  return out_tracers if primitive.multiple_results else out_tracers[0]