def _callback_fun(callback, strip_calls, *in_vals, **params):
  """Run a computation under a fresh ``CallbackTrace`` main trace.

  Generator protocol (presumably driven by an ``lu.transformation``
  wrapper): the first ``yield`` hands ``((main, *in_vals), params)`` to the
  driver, and the value sent back is re-yielded unchanged as the result.

  Args:
    callback: stored on the main trace as ``main.callback``.
    strip_calls: stored on the main trace as ``main.strip_calls``.
    *in_vals: positional values forwarded after the main trace.
    **params: keyword parameters forwarded as-is.
  """
  with core.new_main(CallbackTrace) as callback_main:
    # Configuration is attached as attributes after the main trace is
    # created. NOTE(review): confirm CallbackTrace reads `callback` and
    # `strip_calls` from its main trace rather than from constructor kwargs.
    callback_main.callback = callback
    callback_main.strip_calls = strip_calls
    results = yield (callback_main, *in_vals), params
    # Drop our reference before the `with` exits, presumably so that
    # new_main's exit-time leak check does not see it.
    del callback_main
  yield results
def _batch_outer(axis_name, axis_size, in_dims, main_type, *in_vals):
  """Open a batching main trace, bind the axis and run under a 'vmap' name stack.

  Three contexts are entered in order:
  1. a ``main_type`` main trace created with ``axis_name``;
  2. an axis-environment frame binding ``axis_name`` to ``axis_size``;
  3. a 'vmap' transform name stack.

  The first ``yield`` hands ``(main, in_dims, *in_vals)`` with no keyword
  arguments to the driving generator protocol. The value sent back is
  re-yielded as the result after all three contexts have exited.
  """
  with core.new_main(main_type, axis_name=axis_name) as batch_main, \
       core.extend_axis_env(axis_name, axis_size, batch_main), \
       source_info_util.transform_name_stack('vmap'):
    outs = yield (batch_main, in_dims, *in_vals), {}
    # Release the main trace reference before the contexts unwind.
    del batch_main
  yield outs
def check_errors_toplevel(*args):
  """Run a computation under an ``ErrorTrace`` seeded from ``init_error``.

  The first ``yield`` hands the following to the driving generator, with no
  keyword arguments:
  ``(main, msgs, err, code, *args)``, where ``msgs`` is a tuple of
  ``init_error.msgs`` items. The value sent back is re-yielded as the
  result.
  """
  error = init_error
  with core.new_main(ErrorTrace) as error_main:
    msgs = tuple(error.msgs.items())
    err, code = error.err, error.code
    outs = yield (error_main, msgs, err, code) + args, {}
    # Drop the main-trace reference before new_main's context exits.
    del error_main
  yield outs
def _callback_fun(callback, strip_calls, *in_vals, **params):
  """Run a computation under a ``CallbackTrace`` main trace.

  ``callback`` and ``strip_calls`` are passed to ``core.new_main`` as
  keyword payload for the trace. The first ``yield`` hands
  ``((main, *in_vals), params)`` to the driving generator protocol. The
  value sent back is re-yielded as the result.
  """
  trace_payload = dict(callback=callback, strip_calls=strip_calls)
  with core.new_main(CallbackTrace, **trace_payload) as cb_main:
    out_vals = yield (cb_main, *in_vals), params
    # Release the main trace before the `with` block ends.
    del cb_main
  yield out_vals
def _batch_fun(in_dims, *in_vals, **params):
  """Run a computation under an unnamed ``batching.BatchTrace`` main trace.

  The main trace is created with ``axis_name=None``. The first ``yield``
  hands ``((main, in_dims, *in_vals), params)`` to the driving generator
  protocol. The value sent back is re-yielded as the result.
  """
  with jax_core.new_main(batching.BatchTrace, axis_name=None) as batch_main:
    out_vals = yield (batch_main, in_dims, *in_vals), params
    # Drop the reference before the main trace is torn down.
    del batch_main
  yield out_vals
def jet_fun(order, primals, series):
  """Run a computation under a ``JetTrace`` and densify zero series.

  The trace's ``order`` attribute is set to ``order``. The driving generator
  receives ``(main, primals, series)`` and sends back
  ``(out_primals, out_terms)``.

  Any output term that is the ``zero_series`` sentinel is replaced by a
  list containing ``order`` references to a single ``np.zeros_like`` array
  of the matching primal.
  """
  with core.new_main(JetTrace) as jet_main:
    jet_main.order = order
    out_primals, out_terms = yield (jet_main, primals, series), {}
    # Release the main trace before new_main's context exits.
    del jet_main

  def _materialize(primal, terms):
    # One zeros array, repeated `order` times (same object each slot).
    if terms is zero_series:
      return [np.zeros_like(primal)] * order
    return terms

  out_terms = [_materialize(p, s) for p, s in zip(out_primals, out_terms)]
  yield out_primals, out_terms
def wrapped(*args, **kwargs):
  """Call ``f`` under a ``HarvestTrace`` and return its output plus reaps.

  ``f`` and ``settings`` come from the enclosing scope. ``args`` are
  flattened, and the flat function is wrapped by ``reap_function`` (last
  argument ``False``) bound to a fresh main trace. It is then called with
  the flat argument list.

  Returns:
    ``(out, reaps)``: ``out`` is the flat output unflattened with the
    recorded output tree, and ``reaps`` is the second value returned by the
    reap-wrapped call.
  """
  base_fun = lu.wrap_init(f, kwargs)
  leaves, args_tree = tree_util.tree_flatten(args)
  flat_f, get_out_tree = api_util.flatten_fun_nokwargs(base_fun, args_tree)
  with jax_core.new_main(HarvestTrace) as harvest_main:
    flat_f = reap_function(flat_f, harvest_main, settings, False)
    out_leaves, reaps = flat_f.call_wrapped(leaves)
    del harvest_main
  return tree_util.tree_unflatten(get_out_tree(), out_leaves), reaps
def wrapped(plants, *args, **kwargs):
  """Call ``f`` under a ``HarvestTrace`` with ``plants`` injected.

  ``f`` and ``settings`` come from the enclosing scope. ``plants`` and the
  flattened ``args`` are flattened together into a single leaf list. That
  combined tree structure is given to ``plant_function`` so the wrapped
  call can separate them again.

  Returns:
    ``f``'s output, unflattened with the output tree recorded during the call.
  """
  base_fun = lu.wrap_init(f, kwargs)
  arg_leaves, args_tree = tree_util.tree_flatten(args)
  flat_f, get_out_tree = api_util.flatten_fun_nokwargs(base_fun, args_tree)
  combined_leaves, combined_tree = tree_util.tree_flatten((plants, arg_leaves))
  with jax_core.new_main(HarvestTrace) as harvest_main:
    flat_f = plant_function(flat_f, harvest_main, settings, combined_tree)
    out_leaves = flat_f.call_wrapped(combined_leaves)
    del harvest_main
  return tree_util.tree_unflatten(get_out_tree(), out_leaves)
def mask_fun(fun, logical_env, padded_env, in_vals, polymorphic_shapes):
  """Run ``fun`` under a ``MaskTrace`` with logical env values prepended.

  ``padded_env`` is sorted by key and converted to the hashable pair
  ``(keys, values)`` before being handed to ``mask_subtrace``. The logical
  values for those same keys, in the same order, are prepended to
  ``in_vals`` (which must support list concatenation) for the call.

  Returns:
    ``(out_vals, out_shapes)``, where ``out_shapes`` is the value produced by
    the thunk returned from ``mask_subtrace``.
  """
  sorted_items = sorted(padded_env.items())
  env_keys, padded_vals = unzip2(sorted_items)
  logical_vals = [logical_env[name] for name in env_keys]
  # A tuple pair instead of a dict so the env can be hashed.
  hashable_env = (env_keys, padded_vals)
  with core.new_main(MaskTrace) as mask_main:
    masked_fun, get_out_shapes = mask_subtrace(
        fun, mask_main, polymorphic_shapes, hashable_env)
    out_vals = masked_fun.call_wrapped(*(logical_vals + in_vals))
    del mask_main
  return out_vals, get_out_shapes()
def sparsify_fun(wrapped_fun, args: List[ArrayOrSparse]):
  """Evaluate ``wrapped_fun`` on ``args`` under a ``SparseTrace``.

  ``args`` are converted to sparse values in a fresh ``SparsifyEnv``. The
  sparsified function is called on that env's buffers. The output buffers
  are wrapped in a new ``SparsifyEnv``, and the output sparse values are
  converted back to arrays using it.
  """
  with core.new_main(SparseTrace) as sparse_main:
    in_env = SparsifyEnv()
    spvalues = arrays_to_spvalues(in_env, args)
    # Read the buffers only after arrays_to_spvalues has populated the env.
    in_bufs = in_env._buffers
    sparse_fun, get_out_spvalues = sparsify_subtrace(
        wrapped_fun, sparse_main, spvalues)
    out_bufs = sparse_fun.call_wrapped(*in_bufs)
    out_env = SparsifyEnv(out_bufs)
    del sparse_main
  return spvalues_to_arrays(out_env, get_out_spvalues())
def _batch_jaxpr_outer(axis_name, axis_size, in_dims, main_type, *in_vals):
  """Open a batching main trace for a jaxpr and run the wrapped computation.

  Args:
    axis_name: name bound in the axis environment and given to the main trace.
    axis_size: size of the batch axis. If ``None``, it is inferred from the
      shapes of the mapped inputs, which must all agree.
    in_dims: per-input batch dimensions, or a zero-argument callable
      returning them. Integer dims are canonicalized against each input's
      rank; other entries (e.g. ``not_mapped``) pass through unchanged.
    main_type: trace type used for the new main trace.
    *in_vals: the input values.

  The first ``yield`` hands ``(main, in_dims, *in_vals)`` to the driving
  generator protocol. The value sent back is re-yielded as the result.

  Raises:
    ValueError: if ``axis_size`` is ``None`` and the mapped inputs do not
      determine exactly one axis size.
  """
  # Resolve a lazily-supplied in_dims *before* anything iterates over it;
  # inferring the axis size from an unresolved callable would fail.
  in_dims = in_dims() if callable(in_dims) else in_dims
  if axis_size is None:
    sizes = {x.shape[d] for x, d in zip(in_vals, in_dims)
             if d is not not_mapped}
    if len(sizes) != 1:
      raise ValueError(
          "Could not infer a unique batch axis size from the mapped inputs; "
          f"found sizes {sizes}.")
    axis_size, = sizes
  in_dims = [canonicalize_axis(ax, np.ndim(x)) if isinstance(ax, int) else ax
             for x, ax in zip(in_vals, in_dims)]
  with core.new_main(main_type, axis_name=axis_name) as main:
    with core.extend_axis_env(axis_name, axis_size, main):
      out_vals = yield (main, in_dims, *in_vals), {}
      # Drop the main-trace reference before the contexts unwind.
      del main
  yield out_vals
def doubling_transform(*args):
  """Run a computation on ``DoublingTracer``s and unwrap the results.

  Each element of ``args`` is unpacked as a ``(head, tail)`` pair and
  wrapped in a ``DoublingTracer`` of a fresh ``DoublingTrace``. The
  outputs sent back are raised into that trace and turned back into
  ``(head, tail)`` pairs:
  - if the outputs are a ``Sequence``, the result is a list of pairs;
  - otherwise the result is a single pair.
  """
  with core.new_main(DoublingTrace) as main:
    trace = DoublingTrace(main, core.cur_sublevel())
    in_tracers = [DoublingTracer(trace, head, tail) for head, tail in args]
    outputs = yield in_tracers, {}
    if isinstance(outputs, Sequence):
      out_tracers = map(trace.full_raise, outputs)
      result = [(x.head, x.tail) for x in out_tracers]
      del out_tracers
    else:
      out_tracer = trace.full_raise(outputs)
      result = (out_tracer.head, out_tracer.tail)
      del out_tracer
    # Release every local that refers to the main trace, directly or through
    # its trace/tracers, before new_main's context exits (the usual
    # `del main` pattern). Otherwise they stay alive during teardown, which
    # presumably trips new_main's tracer-leak check.
    del main, trace, in_tracers, outputs
  yield result
def _apply(self, parameters, *inputs, key):
  """Evaluate the wrapped function with ``parameters`` under an ``ApplyTrace``.

  If an enclosing ``ApplyTrace`` exists (as found by ``_top_trace``), its
  global parameters and random state are reused. Otherwise an empty global
  parameter dict and a fresh ``RandomState(key)`` are used.

  Returns:
    The wrapped function's output, unflattened to its original structure.
  """
  flat_inputs, in_tree = tree_flatten(inputs)
  flat_fun, out_tree = flatten_fun_nokwargs(self._wrapped_fun, in_tree)
  apply_trace = _top_trace(filter_type=ApplyTrace)
  with new_main(ApplyTrace) as master:
    if apply_trace:
      global_params = apply_trace.state.global_parameters_by_primitive
      random_state = apply_trace.state.random_state
    else:
      global_params = {}
      random_state = RandomState(key)
    master.state = ApplyTraceState(random_state, parameters, global_params)
    # Kept as a single expression so no extra local retains `master`.
    flat_outputs = _apply_transform(flat_fun, master).call_wrapped(*flat_inputs)
    del master
  return tree_unflatten(out_tree(), flat_outputs)
def jvpfun(instantiate, transform_stack, primals, tangents):
  """Run a computation under a ``JVPTrace`` and post-process its tangents.

  Input tangents that are not already ``Zero`` and whose dtype is ``float0``
  are replaced by ``Zero.from_value``. When ``transform_stack`` is truthy,
  the call runs inside a 'jvp' transform name stack.

  ``instantiate`` is either a single bool applied to every output tangent
  or a per-output sequence. Where it is true, the tangent is passed through
  ``instantiate_zeros``.
  """
  tangents = [
      t if isinstance(t, Zero) or dtype(t) is not float0
      else Zero.from_value(t)
      for t in tangents
  ]
  if transform_stack:
    name_ctx = source_info_util.transform_name_stack('jvp')
  else:
    name_ctx = contextlib.nullcontext()
  with core.new_main(JVPTrace) as jvp_main, name_ctx:
    out_primals, out_tangents = yield (jvp_main, primals, tangents), {}
    del jvp_main
  if type(instantiate) is bool:
    instantiate = [instantiate] * len(out_tangents)
  out_tangents = [
      instantiate_zeros(tangent) if should_instantiate else tangent
      for tangent, should_instantiate in zip(out_tangents, instantiate)
  ]
  yield out_primals, out_tangents
def _init_transform(key, *inputs):
  """Transforms a flattened `parametrized` function into its corresponding
  `init_parameters` function.

  The function runs under a fresh ``InitTrace`` main trace:
  - ``inputs`` are raised into the trace with ``trace.full_raise`` for the
    wrapped call;
  - the outputs sent back are passed through ``trace.lower_all``.

  If an enclosing ``InitTrace`` exists, its global parameters dict and
  random state are reused. Otherwise they start as ``{}`` and
  ``RandomState(key)``.

  Yields (final):
    ``(outs, parameters_dict)``, where ``parameters_dict`` is read from the
    main trace's state after the wrapped call.
  """
  init_trace = _top_trace(filter_type=InitTrace)
  with new_main(InitTrace) as master:
    global_parameters_dict = init_trace.state.global_parameters_dict if init_trace else {}
    random_state = init_trace.state.random_state if init_trace else RandomState(key)
    master.state = InitTraceState(random_state, global_parameters_dict)
    trace = InitTrace(master, cur_sublevel())
    outs = yield map(trace.full_raise, inputs), {}
    outs = trace.lower_all(outs)
    parameters_dict = master.state.parameters_dict
    # `trace` was constructed from `master` (and presumably keeps a reference
    # to it), so it must be released too. Otherwise `del master` does not
    # actually free the main trace before new_main's context exits.
    del trace
    del master
  yield outs, parameters_dict
def wrapped(*args, **kwargs):
  """Callable returned by unzip.

  Traces ``f`` (with ``kwargs`` bound) once, on shaped abstract values of
  ``args``, under an ``UnzipTrace`` main trace. The result is split into two
  functions:

  * ``init(*args)`` evaluates the init jaxpr and returns a dict mapping each
    variable name to its value.
  * ``apply(params, *args)`` looks up each named variable in ``params`` and
    evaluates the apply jaxpr on those variables followed by the flattened
    ``args``. It returns the output unflattened to the traced output tree.

  ``f``, ``key_args`` and ``tag`` are taken from the enclosing scope.

  Returns:
    The pair ``(init, apply)``.

  Raises:
    ValueError: if the unzip trace reports failure (``success`` is false).
  """
  with jax_core.new_main(UnzipTrace) as master:
    # Build the flat, abstract inputs used for tracing.
    traced_fun = lu.wrap_init(f, kwargs)
    arg_avals = tree_util.tree_map(trace_util.get_shaped_aval, args)
    flat_avals, flat_keys, in_tree = flatten_args_into_keys(arg_avals, key_args)
    flat_pvals = [pe.PartialVal.unknown(aval) for aval in flat_avals]
    flat_fun, get_out_tree = api_util.flatten_fun_nokwargs(traced_fun, in_tree)

    # Trace once and split the result into init / apply sub-jaxprs.
    unzip_settings = UnzipSettings(tag, False)
    unzipped = unzip_to_init_apply_subjaxprs(flat_fun, master, unzip_settings)  # pylint: disable=no-value-for-parameter
    success, results = unzipped.call_wrapped(flat_keys, flat_pvals)
    if not success:
      raise ValueError('Variables do not cut dependence graph.')
    init_part, apply_part, _, metadata = results
    init_jaxpr, init_consts, init_env = init_part
    assert not init_env
    apply_jaxpr, apply_consts, apply_env = apply_part
    assert not apply_env
    names, variable_tree, _ = metadata
    out_tree = get_out_tree()

    def init(*init_args):
      flat_init_args, _ = tree_util.tree_flatten(init_args)
      flat_params = jax_core.eval_jaxpr(init_jaxpr, init_consts, *flat_init_args)
      flat_variables = tree_util.tree_unflatten(variable_tree, flat_params)
      return dict(safe_zip(names, flat_variables))

    def apply(params, *apply_args):
      flat_variables, _ = tree_util.tree_flatten(
          [params[name] for name in names])
      flat_apply_args, _ = tree_util.tree_flatten(apply_args)
      out = jax_core.eval_jaxpr(apply_jaxpr, apply_consts,
                                *flat_variables, *flat_apply_args)
      return tree_util.tree_unflatten(out_tree, out)

    # The closures above do not capture `master`; drop our reference before
    # the main trace is torn down.
    del master
  return init, apply
def jvpfun(instantiate, primals, tangents):
  """Run a computation under a ``JVPTrace`` and post-process its tangents.

  Non-``Zero`` input tangents with a ``float0`` dtype become
  ``Zero.from_value(t)``. ``instantiate`` is a single bool for all outputs
  or a per-output sequence. Where it is true, the output tangent is passed
  through ``instantiate_zeros``.
  """
  def _normalize(t):
    # Check for Zero first so dtype() is only queried on non-Zero tangents.
    if isinstance(t, Zero):
      return t
    return Zero.from_value(t) if dtype(t) is float0 else t

  tangents = [_normalize(t) for t in tangents]
  with core.new_main(JVPTrace) as jvp_main:
    out_primals, out_tangents = yield (jvp_main, primals, tangents), {}
    del jvp_main
  if type(instantiate) is bool:
    instantiate = [instantiate] * len(out_tangents)
  out_tangents = [
      instantiate_zeros(tangent) if flag else tangent
      for tangent, flag in zip(out_tangents, instantiate)
  ]
  yield out_primals, out_tangents
def _get_harvest_metadata(closed_jaxpr, settings, *args):
  """Probe a closed jaxpr for harvest metadata such as its sown values.

  The jaxpr is traced to a final jaxpr on the shaped abstract values of
  ``args``. Tracing runs under a ``HarvestTrace``, with a copy of
  ``settings`` whose last field is ``True`` and a reap wrapper in metadata
  mode. The jaxpr produced by tracing is discarded.

  Returns:
    Whatever the metadata wrapper's auxiliary store holds after tracing.
  """
  def _shaped_aval(x):
    return abstract_arrays.raise_to_shaped(jax_core.get_aval(x))

  probe = lu.wrap_init(jax_core.jaxpr_as_fun(closed_jaxpr))
  with jax_core.new_main(HarvestTrace) as harvest_main:
    metadata_settings = HarvestSettings(settings.tag, settings.blocklist,
                                        settings.allowlist, True)
    probe = reap_function(probe, harvest_main, metadata_settings, True)
    probe, get_metadata = _reap_metadata_wrapper(probe)
    flat_args, in_tree = tree_util.tree_flatten(args)
    flat_probe, get_out_tree = api_util.flatten_fun_nokwargs(probe, in_tree)
    in_avals = jax_util.safe_map(_shaped_aval, flat_args)
    pe.trace_to_jaxpr_final(flat_probe, in_avals)
    metadata = get_metadata()
    # Result unused; presumably called to confirm the output tree was
    # recorded during tracing.
    get_out_tree()
  return metadata
def trace_to_jaxpr_dynamic(fun: lu.WrappedFun, in_avals: Sequence[core.AbstractValue]):
  """Trace ``fun`` on ``in_avals`` under a dynamic ``DJaxprTrace`` main trace.

  The main trace is created with ``dynamic=True`` and starts with an empty
  ``jaxpr_stack``. The tracing itself is delegated to
  ``trace_to_subjaxpr_dynamic``.

  Returns:
    The value returned by ``trace_to_subjaxpr_dynamic``.
  """
  with core.new_main(DJaxprTrace, dynamic=True) as dyn_main:
    dyn_main.jaxpr_stack = ()  # type: ignore
    traced = trace_to_subjaxpr_dynamic(fun, dyn_main, in_avals)
    # Release the main trace before new_main's context exits.
    del dyn_main
  return traced
def checkify_traceable(msgs, enabled_errors, err, code, payload, *args):
  """Run a computation under a ``CheckifyTrace`` configured with ``enabled_errors``.

  The first ``yield`` hands
  ``(main, msgs, err, code, payload, *args)`` to the driving generator
  protocol, with no keyword arguments. The value sent back is re-yielded as
  the result.
  """
  with core.new_main(CheckifyTrace, enabled_errors=enabled_errors) as checkify_main:
    outs = yield (checkify_main, msgs, err, code, payload) + args, {}
    # Drop the main-trace reference before the context exits.
    del checkify_main
  yield outs
def _batch_outer(axis_name, axis_size, in_dims, main_type, *in_vals):
  """Open a batching main trace and bind ``axis_name`` to ``axis_size``.

  The main trace is created from ``main_type`` with ``axis_name``. An
  axis-environment frame is then entered for it. The first ``yield`` hands
  ``(main, in_dims, *in_vals)`` to the driving generator protocol. The value
  sent back is re-yielded after both contexts have exited.
  """
  with core.new_main(main_type, axis_name=axis_name) as outer_main, \
       core.extend_axis_env(axis_name, axis_size, outer_main):
    outs = yield (outer_main, in_dims, *in_vals), {}
    # Release the main trace reference before the contexts unwind.
    del outer_main
  yield outs
def check_errors_traceable(msgs, err, code, *args):
  """Run a computation under an ``ErrorTrace`` main trace.

  The first ``yield`` hands ``(main, msgs, err, code, *args)`` to the driving
  generator protocol, with no keyword arguments. The value sent back is
  re-yielded as the result.
  """
  with core.new_main(ErrorTrace) as error_main:
    outs = yield (error_main, msgs, err, code) + args, {}
    # Drop the main-trace reference before new_main's context exits.
    del error_main
  yield outs