def _callback_fun(callback, strip_calls, *in_vals, **params):
    """Generator that runs a computation under a fresh ``CallbackTrace`` master.

    The first ``yield`` emits ``((master, *in_vals), params)`` and receives the
    computation's outputs back through ``send``; the second ``yield`` emits
    those outputs unchanged.

    Args:
      callback: stored on the master as ``master.callback``.
      strip_calls: stored on the master as ``master.strip_calls``.
      *in_vals: positional values forwarded after the master.
      **params: keyword parameters forwarded as-is.
    """
    with new_master(CallbackTrace) as trace_master:
        # Configuration rides along on the master object itself.
        trace_master.callback = callback  # NOTE: Is this OK?
        trace_master.strip_calls = strip_calls
        out_vals = yield (trace_master,) + in_vals, params
        # Drop this frame's reference before the context manager exits
        # (presumably so new_master can verify nothing leaked -- TODO confirm).
        del trace_master
    yield out_vals
def _interpret_fun(fun: lu.WrappedFun,
                   in_vals: Sequence[TfValOrUnit]) -> Sequence[TfValOrUnit]:
    """Call ``fun`` on ``in_vals`` under a fresh ``TensorFlowTrace`` master.

    Args:
      fun: the wrapped function to interpret.
      in_vals: positional inputs, splatted into ``call_wrapped``.

    Returns:
      Whatever ``call_wrapped`` returns for the sub-traced function.
    """
    with core.new_master(TensorFlowTrace) as trace_master:
        traced_fun = _interpret_subtrace(fun, trace_master)
        results: Sequence[TfValOrUnit] = traced_fun.call_wrapped(*in_vals)
        # Release the local handle on the master before leaving the context.
        del trace_master
    return results
def jet_fun(primals, series):
    """Generator that runs a computation under a fresh ``JetTrace`` master.

    The first ``yield`` emits ``((master, primals, series), {})`` and receives
    ``(out_primals, out_terms)`` back. Any entry of ``out_terms`` that is the
    ``zero_series`` sentinel is then replaced by a zero-filled structure shaped
    like ``series[0]``, and the pair is emitted by the second ``yield``.
    """
    with core.new_master(JetTrace) as trace_master:
        out_primals, raw_terms = yield (trace_master, primals, series), {}
        del trace_master

    def _zero_leaf(x):
        # dtype is taken from the first output primal, as in the original.
        return onp.zeros_like(x, dtype=onp.result_type(out_primals[0]))

    out_terms = []
    for terms in raw_terms:
        if terms is zero_series:
            # Built lazily: only evaluated when a sentinel is actually present.
            terms = tree_map(_zero_leaf, series[0])
        out_terms.append(terms)
    yield out_primals, out_terms
def jet_fun(order, primals, series):
    """Generator that runs a computation under a ``JetTrace`` master of given order.

    ``order`` is stored on the master before the first ``yield``, which emits
    ``((master, primals, series), {})`` and receives ``(out_primals, out_terms)``.
    Each ``zero_series`` sentinel in ``out_terms`` is replaced by a list of
    ``order`` zero arrays shaped like the matching output primal; the resulting
    pair is emitted by the second ``yield``.
    """
    with core.new_master(JetTrace) as trace_master:
        trace_master.order = order
        out_primals, raw_terms = yield (trace_master, primals, series), {}
        del trace_master

    out_terms = []
    for primal, terms in zip(out_primals, raw_terms):
        if terms is zero_series:
            # Same object repeated `order` times, exactly like `[z] * order`.
            terms = [np.zeros_like(primal)] * order
        out_terms.append(terms)
    yield out_primals, out_terms
def apply_transform(net_params, inputs):
    """Generator that evaluates a computation on ``ApplyTracer``-wrapped inputs.

    Each element of ``inputs`` is wrapped as ``ApplyTracer(trace, net_params, x)``
    and emitted by the first ``yield`` (with empty keyword params). The value
    sent back is raised into the trace and its ``.val`` is emitted by the
    second ``yield``.
    """
    with jc.new_master(ApplyTrace) as trace_master:
        trace = ApplyTrace(trace_master, jc.cur_sublevel())
        make_tracer = partial(ApplyTracer, trace, net_params)
        ans = yield map(make_tracer, inputs), {}
        # Unwrap the result; no tracer reference is kept in a local.
        out_val = trace.full_raise(ans).val
        del trace_master
    yield out_val
def jet_transform(primals, series):
    """Generator that evaluates a computation on ``JetTracer``-wrapped inputs.

    Pairs of ``(primal, series)`` are wrapped as ``JetTracer(trace, p, s)`` and
    emitted by the first ``yield``. The outputs sent back are raised into the
    trace and split into ``(out_primals, out_terms)``, which the second
    ``yield`` emits.
    """
    with core.new_master(JetTrace) as trace_master:
        trace = JetTrace(trace_master, core.cur_sublevel())
        make_tracer = partial(JetTracer, trace)
        ans = yield map(make_tracer, primals, series), {}
        primal_term_pairs = (
            (t.primal, t.terms) for t in map(trace.full_raise, ans))
        out_primals, out_terms = unzip2(primal_term_pairs)
    yield out_primals, out_terms
def inner():
    """Run ``self._wrapped_fun`` on ``inputs`` under a fresh ``ApplyTrace`` master.

    ``inputs``, ``params`` and ``self`` are taken from the enclosing scope.
    The inputs are flattened, the flattened function is evaluated through
    ``ApplyTrace._apply_subtrace``, and the flat outputs are rebuilt into the
    output tree.

    Returns:
      The outputs of the wrapped function, unflattened with ``out_tree()``.
    """
    flat_inputs, in_tree = tree_util.tree_flatten(inputs)
    flat_fun, out_tree = api_util.flatten_fun_nokwargs(
        self._wrapped_fun, in_tree)
    with jc.new_master(ApplyTrace) as master:
        flat_fun = ApplyTrace._apply_subtrace(
            flat_fun, master, WrapHashably(params))
        # flat_fun was built against in_tree, so it must receive the flattened
        # leaves. Passing the raw `inputs` only happened to work when every
        # input was already a leaf; nested inputs gave a leaf-count mismatch.
        flat_outputs = flat_fun.call_wrapped(*flat_inputs)
        del master
    return tree_util.tree_unflatten(out_tree(), flat_outputs)
def wrapped(plants, *args, **kwargs):
    """Call ``f`` under a fresh ``HarvestTrace`` master.

    ``f`` and ``settings`` come from the enclosing scope. ``args`` are
    flattened, then flattened again together with ``plants``; the combined flat
    list is passed as a single argument to the harvested function.

    Returns:
      ``(out, reaped)`` where ``out`` is the unflattened output of ``f`` and
      ``reaped`` is the second value returned by the harvested call.
    """
    wrapped_f = lu.wrap_init(f, kwargs)
    flat_args, in_tree = tree_util.tree_flatten(args)
    flat_fun, out_tree = api_util.flatten_fun_nokwargs(wrapped_f, in_tree)
    all_args, all_tree = tree_util.tree_flatten((plants, flat_args))
    with jax_core.new_master(HarvestTrace) as trace_master:
        harvested = harvest_function(
            flat_fun, trace_master, settings, all_tree)
        # NOTE: the flat list is passed as ONE positional argument, not splatted.
        out_flat, reaped = harvested.call_wrapped(all_args)
        del trace_master
    return tree_util.tree_unflatten(out_tree(), out_flat), reaped
def doubling_transform(*args):
    """Generator that evaluates a computation on ``DoublingTracer`` inputs.

    Each positional argument is unpacked as a ``(head, tail)`` pair and wrapped
    in a ``DoublingTracer``. The first ``yield`` emits the tracers; the value
    sent back is raised into the trace and converted back to ``(head, tail)``
    form -- a list of pairs when the output is a ``Sequence``, a single pair
    otherwise -- which the second ``yield`` emits.
    """
    with core.new_master(DoublingTrace) as trace_master:
        trace = DoublingTrace(trace_master, core.cur_sublevel())
        in_tracers = [DoublingTracer(trace, head, tail) for head, tail in args]
        outputs = yield in_tracers, {}
        if not isinstance(outputs, Sequence):
            # Single output: return one (head, tail) pair.
            raised = trace.full_raise(outputs)
            result = (raised.head, raised.tail)
        else:
            result = [(t.head, t.tail)
                      for t in map(trace.full_raise, outputs)]
    yield result
def wrapped(*args, **kwargs):
    """Callable returned by unzip.

    Traces ``f`` (from the enclosing scope) under a fresh ``UnzipTrace`` master
    with every flattened input marked as an unknown partial value, splits the
    result into an init jaxpr and an apply jaxpr, and returns two closures
    that evaluate them.

    Returns:
      ``(init, apply)``: ``init(*args)`` returns a dict mapping variable names
      to values; ``apply(params, *args)`` evaluates the apply jaxpr with the
      named variables looked up in ``params``.

    Raises:
      ValueError: if the traced call reports that it did not succeed.
    """
    with jax_core.new_master(UnzipTrace) as trace_master:
        # Preparing args to be traced
        trace = UnzipTrace(trace_master, jax_core.cur_sublevel())
        fun = lu.wrap_init(f, kwargs)
        avals = tree_util.tree_map(trace_util.get_shaped_aval, args)
        flat_avals, flat_keys, in_tree = flatten_args_into_keys(
            avals, key_args)
        flat_pvals = [pe.PartialVal.unknown(aval) for aval in flat_avals]
        flat_fun, out_tree_thunk = api_util.flatten_fun_nokwargs(fun, in_tree)

        # Trace to jaxpr
        settings = UnzipSettings(tag, False)
        unzipped = unzip_to_init_apply_subjaxprs(flat_fun, trace, settings)  # pylint: disable=no-value-for-parameter
        success, results = unzipped.call_wrapped(flat_keys, flat_pvals)
        if not success:
            raise ValueError('Variables do not cut dependence graph.')

        init_out, apply_out, _, metadata = results
        init_jaxpr, init_consts, init_env = init_out
        assert not init_env
        apply_jaxpr, apply_consts, apply_env = apply_out
        assert not apply_env
        names, variable_tree, _ = metadata
        # Resolve the output tree once; `apply` closes over the result.
        out_tree = out_tree_thunk()

        # Final functions
        def init(*args):
            flat_args, _ = tree_util.tree_flatten(args)
            flat_params = jax_core.eval_jaxpr(
                init_jaxpr, init_consts, *flat_args)
            flat_variables = tree_util.tree_unflatten(
                variable_tree, flat_params)
            return dict(safe_zip(names, flat_variables))

        def apply(params, *args):
            named_values = [params[name] for name in names]
            flat_variables, _ = tree_util.tree_flatten(named_values)
            flat_args, _ = tree_util.tree_flatten(args)
            out = jax_core.eval_jaxpr(
                apply_jaxpr, apply_consts, *(flat_variables + flat_args))
            return tree_util.tree_unflatten(out_tree, out)

        del trace_master
    return init, apply
def _init_transform(key, *inputs):
    """Transforms a flattened `parametrized` function into its corresponding
    `init_parameters` function.

    If an enclosing ``InitTrace`` is found, its global-parameters dict and
    random state are reused; otherwise an empty dict and ``RandomState(key)``
    are used. The first ``yield`` emits the inputs raised into the new trace;
    the second emits ``(outs, parameters_dict)``.
    """
    init_trace = _top_trace(filter_type=InitTrace)
    with new_master(InitTrace) as trace_master:
        if init_trace:
            # Nested call: share state with the enclosing init trace.
            global_parameters_dict = init_trace.state.global_parameters_dict
            random_state = init_trace.state.random_state
        else:
            global_parameters_dict = {}
            random_state = RandomState(key)
        trace_master.state = InitTraceState(
            random_state, global_parameters_dict)
        trace = InitTrace(trace_master, cur_sublevel())

        outs = yield map(trace.full_raise, inputs), {}
        outs = trace.lower_all(outs)
        # Read the collected parameters before dropping the master reference.
        parameters_dict = trace_master.state.parameters_dict
        del trace_master
    yield outs, parameters_dict
def _apply(self, parameters, *inputs, key):
    """Evaluate ``self._wrapped_fun`` on ``inputs`` under a fresh ``ApplyTrace``.

    If an enclosing ``ApplyTrace`` is found, its per-primitive global
    parameters and random state are reused; otherwise an empty dict and
    ``RandomState(key)`` are used.

    Args:
      parameters: stored in the new trace state alongside the random state.
      *inputs: positional inputs; flattened before the call.
      key: seed for ``RandomState`` when there is no enclosing apply trace.

    Returns:
      The function outputs, unflattened with ``out_tree()``.
    """
    flat_inputs, in_tree = tree_flatten(inputs)
    flat_fun, out_tree = flatten_fun_nokwargs(self._wrapped_fun, in_tree)
    apply_trace = _top_trace(filter_type=ApplyTrace)

    with new_master(ApplyTrace) as trace_master:
        if apply_trace:
            # Nested call: share state with the enclosing apply trace.
            global_parameters_by_primitive = (
                apply_trace.state.global_parameters_by_primitive)
            random_state = apply_trace.state.random_state
        else:
            global_parameters_by_primitive = {}
            random_state = RandomState(key)
        trace_master.state = ApplyTraceState(
            random_state, parameters, global_parameters_by_primitive)
        transformed = _apply_transform(flat_fun, trace_master)
        flat_outputs = transformed.call_wrapped(*flat_inputs)
        del trace_master
    return tree_unflatten(out_tree(), flat_outputs)
def jet_fun(primals, series):
    """Generator that runs a computation under a fresh ``JetTrace`` master.

    The first ``yield`` emits ``((master, primals, series), {})`` and receives
    ``(out_primals, out_terms)`` back; the second ``yield`` emits that pair
    unchanged.
    """
    with core.new_master(JetTrace) as trace_master:
        call_result = yield (trace_master, primals, series), {}
        out_primals, out_terms = call_result
        # Release the local handle on the master before leaving the context.
        del trace_master
    yield out_primals, out_terms