def trace_to_jaxpr_finalize(in_tracers, out_tracers, trace, instantiate=True):
  """Finalizes a partial-evaluation trace into a closure-converted `TypedJaxpr`.

  Args:
    in_tracers: tracers for the traced function's inputs.
    out_tracers: tracers for the traced function's outputs.
    trace: the trace on which the tracers live.
    instantiate: whether to instantiate constant outputs; broadcast to one
      flag per output tracer.

  Returns:
    A `(typed_jaxpr, consts)` pair where `typed_jaxpr` takes the constants as
    its leading arguments (via `convert_constvars_jaxpr`).
  """
  # TODO: This is the final part of the partial_eval.trace_to_subjaxpr. Share.
  instantiate = [instantiate] * len(out_tracers)
  # Lower each output to its most concrete form, then re-raise it so it
  # belongs to `trace` before instantiating constants.
  out_tracers = safe_map(trace.full_raise,
                         safe_map(core.full_lower, out_tracers))
  out_tracers = safe_map(
      partial(pe.instantiate_const_at, trace), instantiate, out_tracers)
  jaxpr, consts, env = pe.tracers_to_jaxpr(in_tracers, out_tracers)
  out_pvals = [t.pval for t in out_tracers]
  # TODO: this is from partial_eval.trace_to_jaxpr. Share.
  # A non-empty environment would mean the jaxpr closes over values from an
  # enclosing trace, which this finalization does not support.
  assert not env
  # TODO: this is from the final part of lax_control_flow._initial_style_jaxpr
  out_avals = safe_map(abstract_arrays.raise_to_shaped, unzip2(out_pvals)[0])
  const_avals = tuple(
      abstract_arrays.raise_to_shaped(core.get_aval(c)) for c in consts)
  in_pvals = [t.pval for t in in_tracers]
  in_avals = tuple(
      safe_map(abstract_arrays.raise_to_shaped, unzip2(in_pvals)[0]))
  # Closure-convert so the constants become explicit leading arguments of the
  # returned jaxpr; callers must pass `consts` back in when evaluating it.
  typed_jaxpr = core.TypedJaxpr(pe.convert_constvars_jaxpr(jaxpr), (),
                                const_avals + in_avals, out_avals)
  return typed_jaxpr, consts
def new(cls, val):
  """Builds an `InverseAndILDJ` for a fully-known concrete value."""
  # The unit value carries no information; represent it as fully unknown.
  if val is jax_core.unit:
    return InverseAndILDJ.unknown(jax_core.abstract_unit)
  arr = np.array(val)
  shaped_aval = abstract_arrays.raise_to_shaped(jax_core.get_aval(arr))
  # A known value is one slice covering the whole array with a zero ILDJ.
  full_slice = NDSlice.new(arr, np.zeros_like(arr))
  return InverseAndILDJ(shaped_aval, frozenset([full_slice]))
def _initial_style_jaxpr(fun, in_tree, in_avals):
  """Traces `fun` at the given avals into a closure-converted `TypedJaxpr`."""
  pvals = [pe.PartialVal((aval, core.unit)) for aval in in_avals]
  flat_fun, out_tree = flatten_fun_nokwargs(lu.wrap_init(fun), in_tree)
  jaxpr, out_pvals, consts = pe.trace_to_jaxpr(flat_fun, pvals,
                                               instantiate=True)
  out_avals = _map(raise_to_shaped, unzip2(out_pvals)[0])
  const_avals = tuple(raise_to_shaped(core.get_aval(const))
                      for const in consts)
  # Closure-convert so constants become explicit leading arguments.
  converted = pe.closure_convert_jaxpr(jaxpr)
  typed_jaxpr = core.TypedJaxpr(converted, (), const_avals + in_avals,
                                out_avals)
  return typed_jaxpr, consts, out_tree()
def instantiate_const_abstracted(self, tracer):
  """Replaces a known-constant tracer by an unknown tracer with a shaped aval."""
  pv, const = tracer.pval
  # Already abstract: nothing to instantiate.
  if isinstance(pv, jax_core.AbstractValue):
    return tracer
  if pv is not None:
    raise TypeError(pv)
  # Known constant: abstract it to a (possibly weakly-typed) shaped aval.
  shaped = abstract_arrays.raise_to_shaped(
      trace_util.get_shaped_aval(const), onp.isscalar(const))
  return UnzipTracer(self, pe.PartialVal.unknown(shaped),
                     pe.ConstVar(const), tracer.is_key())
def handle_sow(self, *values, name, tag, tree, mode):
  """Records the sown `values` under `name` in the reaps dictionary."""
  del tag  # Unused: reaping at this point does not distinguish tags.
  if name in self.reaps:
    raise ValueError(f'Variable has already been reaped: {name}')
  shaped_avals = [
      abstract_arrays.raise_to_shaped(jax_core.get_aval(value))
      for value in values
  ]
  aval_tree = tree_util.tree_unflatten(tree, shaped_avals)
  value_tree = tree_util.tree_unflatten(tree, values)
  self.reaps[name] = Reap(value_tree, dict(mode=mode, aval=aval_tree))
  return values
def custom_jvp_call_jaxpr(fun, jvp, *args):
  """A convenience wrapper to apply the custom_jvp_call_jaxpr primitive."""
  in_avals = [
      abstract_arrays.raise_to_shaped(jax_core.get_aval(arg)) for arg in args
  ]
  # consts can be tracers, so they are threaded through as primitive operands.
  fun_jaxpr, consts = cd._initial_style_jaxpr(  # pylint: disable=protected-access
      fun, in_avals)
  closed_fun_jaxpr = jax_core.ClosedJaxpr(
      pe.convert_constvars_jaxpr(fun_jaxpr), ())
  # The JVP jaxpr is built lazily; it receives primal and tangent avals.
  jvp_jaxpr_thunk = pe._memoize(  # pylint: disable=protected-access
      lambda: cd._initial_style_jaxpr(jvp, in_avals * 2))  # pylint: disable=protected-access
  return cd.custom_jvp_call_jaxpr_p.bind(
      *consts,
      *args,
      fun_jaxpr=closed_fun_jaxpr,
      jvp_jaxpr_thunk=jvp_jaxpr_thunk,
      num_consts=len(consts))
def _get_harvest_metadata(closed_jaxpr, settings, *args):
  """Probes a jaxpr for metadata like its sown values.

  Traces `closed_jaxpr` as a function under a fresh `HarvestTrace` purely for
  the side effect of collecting reap metadata; no transformed jaxpr is kept.

  Args:
    closed_jaxpr: the closed jaxpr to probe.
    settings: harvest settings whose tag and allow/block lists are reused.
    *args: example arguments used to derive input avals for tracing.

  Returns:
    The metadata dictionary produced by the reap-metadata wrapper.
  """
  fun = lu.wrap_init(jax_core.jaxpr_as_fun(closed_jaxpr))
  with jax_core.new_main(HarvestTrace) as main:
    # Rebuild the settings with the final flag forced to True.
    # NOTE(review): presumably that flag enables exclusive/metadata-collection
    # mode — confirm against HarvestSettings' definition.
    settings = HarvestSettings(settings.tag, settings.blocklist,
                               settings.allowlist, True)
    fun = reap_function(fun, main, settings, True)
    fun, aux = _reap_metadata_wrapper(fun)
    flat_args, in_tree = tree_util.tree_flatten(args)
    flat_fun, out_tree = api_util.flatten_fun_nokwargs(fun, in_tree)
    in_avals = jax_util.safe_map(
        lambda a: abstract_arrays.raise_to_shaped(jax_core.get_aval(a)),
        flat_args)
    # Trace for side effects only: the wrapper deposits metadata into `aux`.
    pe.trace_to_jaxpr_final(flat_fun, in_avals)
    metadata = aux()
    out_tree()  # Consume the out-tree store set during tracing.
  return metadata
def custom_vjp_call_jaxpr(fun, fwd, bwd, *args, out_trees):
  """Binds the custom_vjp_call_jaxpr primitive for `fun` with custom fwd/bwd."""
  in_avals = [
      abstract_arrays.raise_to_shaped(jax_core.get_aval(arg)) for arg in args
  ]
  # consts can be tracers, so they are threaded through as primitive operands.
  fun_jaxpr, consts = cd._initial_style_jaxpr(  # pylint: disable=protected-access
      fun, in_avals)
  closed_fun_jaxpr = jax_core.ClosedJaxpr(
      pe.convert_constvars_jaxpr(fun_jaxpr), ())
  # The forward jaxpr is only built on demand (e.g. under differentiation).
  fwd_jaxpr_thunk = pe._memoize(  # pylint: disable=protected-access
      lambda: cd._initial_style_jaxpr(fwd, in_avals))  # pylint: disable=protected-access
  return cd.custom_vjp_call_jaxpr_p.bind(
      *consts,
      *args,
      fun_jaxpr=closed_fun_jaxpr,
      fwd_jaxpr_thunk=fwd_jaxpr_thunk,
      bwd=bwd,
      out_trees=out_trees,
      num_consts=len(consts))
def aval(self):
  """Shaped abstract value corresponding to the wrapped concrete value."""
  concrete_aval = jax_core.get_aval(self.val)
  return abstract_arrays.raise_to_shaped(concrete_aval)
def get_shaped_aval(x):
  """Converts a JAX value type into a shaped abstract value."""
  is_array_like = hasattr(x, 'dtype') and hasattr(x, 'shape')
  if not is_array_like:
    # Non-array values go through JAX's generic abstraction machinery.
    return abstract_arrays.raise_to_shaped(jax_core.get_aval(x))
  canonical_dtype = dtypes.canonicalize_dtype(x.dtype)
  return abstract_arrays.ShapedArray(x.shape, canonical_dtype)
def abstractify(x):
  """Returns the shaped abstract value describing `x`."""
  concrete = core.get_aval(x)
  return abstract_arrays.raise_to_shaped(concrete)
def _all_to_all_abstract_eval(x, axis_name, split_axis, concat_axis):
  """Abstract eval: the split axis' extent moves to the concat position."""
  shaped = raise_to_shaped(x)
  dims = list(shaped.shape)
  # Remove the split dimension and re-insert its size at the concat position.
  moved = dims.pop(split_axis)
  dims.insert(concat_axis, moved)
  return ShapedArray(tuple(dims), shaped.dtype, weak_type=False)
if axis_index_groups is not None: size = len(axis_index_groups[0]) elif type(axis_name) is tuple: size = prod([core.axis_frame(name).size for name in axis_name]) # type: ignore else: size = core.axis_frame(axis_name).size # type: ignore return tuple(size * x for x in args) return core.Primitive.bind(psum_p, *args, axis_name=axis_name, axis_index_groups=axis_index_groups) pmax_p = core.Primitive('pmax') pmax_p.def_abstract_eval(lambda x, **params: raise_to_shaped(x)) xla.parallel_translations[pmax_p] = \ partial(_allreduce_translation_rule, lax.max_p) batching.split_axis_rules[pmax_p] = partial(_split_axis_comm_assoc, pmax_p) batching.primitive_batchers[pmax_p] = partial(_collective_batcher, pmax_p) batching.collective_rules[pmax_p] = \ partial(_batched_reduction_collective, pmax_p, lambda v, d: v.max(d), lambda v, axis_size: v) pmin_p = core.Primitive('pmin') pmin_p.def_abstract_eval(lambda x, **params: raise_to_shaped(x)) xla.parallel_translations[pmin_p] = \ partial(_allreduce_translation_rule, lax.min_p) batching.split_axis_rules[pmin_p] = partial(_split_axis_comm_assoc, pmin_p)
def _gamma_batching_rule(batched_args, batch_dims): k, a = batched_args bk, ba = batch_dims size = next(t.shape[i] for t, i in zip(batched_args, batch_dims) if i is not None) k = batching.bdim_at_front(k, bk, size) a = batching.bdim_at_front(a, ba, size) return random_gamma_p.bind(k, a), (0, ) random_gamma_p = core.Primitive('random_gamma') random_gamma_p.multiple_results = True random_gamma_p.def_impl(_gamma_impl) random_gamma_p.def_abstract_eval(lambda key, a: (abstract_arrays.raise_to_shaped(a), )) ad.defjvp2(random_gamma_p, None, lambda tangent, ans, key, a: (tangent * _gamma_grad(ans[0], a), )) xla.translations[random_gamma_p] = xla.lower_fun(_gamma_impl) batching.primitive_batchers[random_gamma_p] = _gamma_batching_rule def gamma(key, a, shape=None, dtype=onp.float64): """Sample Gamma random values with given shape and float dtype. Args: key: a PRNGKey used as the random key. a: a float or array of floats broadcast-compatible with ``shape`` representing the parameter of the distribution. shape: optional, a tuple of nonnegative integers specifying the result shape. Must be broadcast-compatible with ``a``. The default (None)
else: samples = vmap(_gamma_one)(keys, alphas) return jnp.reshape(samples, a_shape), def _gamma_batching_rule(batched_args, batch_dims): k, a = batched_args bk, ba = batch_dims size = next(t.shape[i] for t, i in zip(batched_args, batch_dims) if i is not None) k = batching.bdim_at_front(k, bk, size) a = batching.bdim_at_front(a, ba, size) return random_gamma_p.bind(k, a), (0,) random_gamma_p = core.Primitive('random_gamma') random_gamma_p.multiple_results = True random_gamma_p.def_impl(_gamma_impl) random_gamma_p.def_abstract_eval(lambda key, a: (abstract_arrays.raise_to_shaped(a),)) ad.defjvp2(random_gamma_p, None, lambda tangent, ans, key, a: (tangent * _gamma_grad(ans[0], a),)) xla.translations[random_gamma_p] = xla.lower_fun(_gamma_impl) batching.primitive_batchers[random_gamma_p] = _gamma_batching_rule def gamma(key, a, shape=None, dtype=np.float64): """Sample Gamma random values with given shape and float dtype. Args: key: a PRNGKey used as the random key. a: a float or array of floats broadcast-compatible with ``shape`` representing the parameter of the distribution. shape: optional, a tuple of nonnegative integers specifying the result shape. Must be broadcast-compatible with ``a``. The default (None) produces a result shape equal to ``a.shape``. dtype: optional, a float dtype for the returned values (default float64 if
def new(cls, val):
  """Constructs an `InverseAndILDJ` for a known value with a zero ILDJ."""
  raw_aval = jax_core.get_aval(val)
  # Unit values carry no information; represent them as unknown.
  if raw_aval is jax_core.abstract_unit:
    return cls.unknown(raw_aval)
  shaped = abstract_arrays.raise_to_shaped(raw_aval)
  return InverseAndILDJ(shaped, val, np.array(0.))
def typematch(aval1, aval2):
  """True iff the two avals agree once both are raised to shaped form."""
  lhs = raise_to_shaped(aval1)
  rhs = raise_to_shaped(aval2)
  return lhs == rhs
def typecheck(aval, x):
  """Checks whether the value `x` is compatible with abstract value `aval`."""
  shaped = raise_to_shaped(aval)
  try:
    # `x` typechecks iff joining its aval with `shaped` is a no-op.
    return shaped == core.lattice_join(shaped, core.get_aval(x))
  except TypeError:
    # Incomparable avals cannot be joined; treat that as a mismatch.
    return False
def _abstractify(x):
  """Returns the shaped abstract value for `x`."""
  concrete = core.get_aval(x)
  return raise_to_shaped(concrete)
def get_shaped_aval(x):
  """Maps a JAX value to a shaped abstract value."""
  is_array_like = hasattr(x, 'dtype') and hasattr(x, 'shape')
  if not is_array_like:
    # Fall back to JAX's generic abstraction for non-array values.
    return abstract_arrays.raise_to_shaped(jax_core.get_aval(x))
  return abstract_arrays.ShapedArray(x.shape, x.dtype)
def _scan_harvest_rule(trace: HarvestTrace, *tracers, length, reverse, jaxpr,
                       num_consts, num_carry, linear, unroll):
  """Collects and injects values into/from the scan body.

  Wraps the scan body in `harvest` so that plants are injected into each
  iteration and reaps are threaded out as extra scan outputs, then re-binds
  `lax.scan_p` with the augmented inputs and re-sows the collected values.
  """
  context = trace_util.get_dynamic_context(trace)
  settings = context.settings
  values = [t.val for t in tracers]
  consts, init, xs = jax_util.split_list(values, [num_consts, num_carry])
  active_sows = _find_sows(jaxpr, settings.tag)
  active_modes = [params['mode'] for params in active_sows]
  # Strict sows may not appear more than once, which a scan body would violate.
  if any(mode == 'strict' for mode in active_modes):
    raise ValueError('Cannot use strict mode in a scan.')
  active_names = [params['name'] for params in active_sows]
  sow_modes = {name: mode for name, mode in zip(active_names, active_modes)}
  # 'clobber' plants are carried (same value every iteration); 'append' plants
  # are scanned over (one slice per iteration).
  carry_plants = {
      name: context.plants[name]
      for name in active_names
      if name in context.plants and sow_modes[name] == 'clobber'
  }
  xs_plants = {
      name: context.plants[name]
      for name in active_names
      if name in context.plants and sow_modes[name] == 'append'
  }

  def jaxpr_fun(carry, x):
    # Re-evaluates the original scan body on (consts, carry, x).
    body_out = jax_core.eval_jaxpr(jaxpr.jaxpr, jaxpr.literals,
                                   *(consts + carry + x))
    carry, y = jax_util.split_list(body_out, [num_carry])
    return carry, y

  harvest_body = harvest(jaxpr_fun,
                         tag=settings.tag,
                         allowlist=settings.allowlist,
                         blocklist=settings.blocklist)

  def body(carry, x):
    # `x` bundles the per-iteration plants with the original xs slice; reaps
    # from the body become extra per-iteration outputs.
    x_plants, x_vals = x
    (carry, y), reaps = harvest_body({
        **carry_plants,
        **x_plants
    }, carry, x_vals)
    return carry, (y, reaps)

  xs_flat = tree_util.tree_leaves((xs_plants, xs))
  x_avals = []
  for x in xs_flat:
    x_aval = jax_core.get_aval(x)
    if x_aval is jax_core.abstract_unit:
      x_avals.append(x_aval)
    else:
      # Drop the leading scan axis to get the per-iteration aval; shapes may
      # be masked, hence padded_shape_as_value.
      x_shape, x_dtype = masking.padded_shape_as_value(x.shape[1:]), x.dtype
      x_avals.append(abstract_arrays.ShapedArray(x_shape, x_dtype))
  x_avals = tuple(x_avals)
  init_avals = tuple(
      abstract_arrays.raise_to_shaped(jax_core.get_aval(a)) for a in init)
  in_flat, in_tree = tree_util.tree_flatten((init, (xs_plants, xs)))
  body_jaxpr, new_consts, out_tree = (
      jax.lax.lax_control_flow._initial_style_jaxpr(  # pylint: disable=protected-access
          body, in_tree, init_avals + x_avals))
  new_values = list(new_consts) + in_flat
  # Extra flat inputs beyond consts/init/xs are the injected xs-plants.
  num_xs_plants = len(new_values) - len(init) - len(xs) - len(new_consts)
  # Rebuild the linearity flags for the augmented input list: new consts and
  # plant inputs are non-linear; original flags are preserved elsewhere.
  remaining_linear = linear[num_consts:]
  new_linear = ((False, ) * len(new_consts) + remaining_linear[:len(init)] +
                (False, ) * num_xs_plants + remaining_linear[len(init):])
  assert len(new_linear) == len(new_values)
  outs = lax.scan_p.bind(*new_values,
                         length=length,
                         reverse=reverse,
                         jaxpr=body_jaxpr,
                         num_consts=len(new_consts),
                         num_carry=num_carry,
                         linear=new_linear,
                         unroll=unroll)
  outs = safe_map(trace.pure, outs)
  carry, (ys, reaps) = tree_util.tree_unflatten(out_tree, outs)
  out_reaps = {}
  for k, val in reaps.items():
    mode = sow_modes.get(k, 'strict')
    if mode == 'append':
      # Stacked per-iteration reaps are flattened along the scan axis.
      val = tree_util.tree_map(np.concatenate, val)
    elif mode == 'clobber':
      # Only the value from the final iteration survives.
      val = tree_util.tree_map(lambda x: x[-1], val)
    # Re-sow outside the scan so outer harvesters can observe the values.
    out_reaps[k] = sow(val, tag=settings.tag, name=k, mode='strict')
  # Tie the outputs to the re-sown values so they are not dead-code eliminated.
  (carry, ys) = prim.tie_in(out_reaps, (carry, ys))
  return carry + ys