def _init_and_apply_parameters_dict(self, *example_inputs, key):
    """Run the wrapped function on example inputs under ``_init_transform``.

    Args:
        *example_inputs: Positional inputs (arbitrary pytrees) for the wrapped function.
        key: Passed straight through to ``_init_transform``.

    Returns:
        A pair ``(parameters, outputs)``.
        - ``parameters`` is whatever the parameters thunk from ``_init_transform``
          yields (presumably a dict of parameters, per this method's name).
        - ``outputs`` is the wrapped function's result, restructured into its
          original pytree form.
    """
    leaves, treedef = tree_flatten(example_inputs)
    flat_fn, out_tree_thunk = flatten_fun_nokwargs(self._wrapped_fun, treedef)
    init_fn, parameters_thunk = _init_transform(flat_fn, key)
    flat_results = init_fn.call_wrapped(*leaves)
    # The output-tree thunk is only filled in by running call_wrapped above, so
    # it (and, presumably, the parameters thunk) must be read afterwards.
    structured_outputs = tree_unflatten(out_tree_thunk(), flat_results)
    return parameters_thunk(), structured_outputs
def abstract_eval(self, *avals, **kwargs):
    """Abstractly evaluate the example-outputs function on ``avals``.

    ``kwargs`` must carry ``in_tree`` and ``out_tree_container``; both are
    extracted with ``split_dict``.
    - ``in_tree`` describes how ``avals`` unflatten into the function's
      positional arguments.
    - The resulting output tree is not returned. It is appended to
      ``out_tree_container`` instead.

    Returns:
        The flat abstract outputs produced by ``_instantiated_trace_to_jaxpr``.
    """
    tree_in, container = split_dict(kwargs, ['in_tree', 'out_tree_container'])
    flat_fn, out_tree_thunk = flatten_fun_nokwargs(self._wrapped_example_outputs_fun, tree_in)
    # Tracing is what populates out_tree_thunk; the other two results are unused.
    _, abstract_outs, _ = _instantiated_trace_to_jaxpr(flat_fn, avals)
    # Hand the output structure back to the caller through the container.
    container.append(out_tree_thunk())
    return abstract_outs
def _apply(self, parameters, *inputs, key):
    """Evaluate the wrapped function on ``inputs`` using the given ``parameters``.

    The call runs under a fresh ``ApplyTrace`` main trace.
    - If an enclosing ``ApplyTrace`` is active, its random state and its
      ``global_parameters_by_primitive`` table are reused, and ``key`` is ignored.
    - Otherwise a new ``RandomState(key)`` and an empty table are used.

    Returns:
        The wrapped function's outputs, restructured into their pytree form.
    """
    leaves, treedef = tree_flatten(inputs)
    flat_fn, out_tree_thunk = flatten_fun_nokwargs(self._wrapped_fun, treedef)
    enclosing_trace = _top_trace(filter_type=ApplyTrace)
    with new_main(ApplyTrace) as apply_main:
        if enclosing_trace:
            shared_parameters = enclosing_trace.state.global_parameters_by_primitive
            rng_state = enclosing_trace.state.random_state
        else:
            shared_parameters = {}
            rng_state = RandomState(key)
        apply_main.state = ApplyTraceState(rng_state, parameters, shared_parameters)
        # Keep the transformed function as a temporary, so that no extra local
        # holds on to apply_main.
        flat_results = _apply_transform(flat_fn, apply_main).call_wrapped(*leaves)
        # Drop our reference before the context manager exits.
        # NOTE(review): presumably this is for new_main's leaked-trace check.
        del apply_main
    return tree_unflatten(out_tree_thunk(), flat_results)
def abstract_call(*inputs):
    """Abstractly evaluate ``self._init_and_apply`` on ``inputs``.

    An abstract rng key is prepended to ``inputs`` before tracing. ``self`` is
    taken from the enclosing scope.

    Side effect:
        Stores the output-tree thunk on ``self._cached_out_tree``. Tracing below
        is what populates that thunk.

    Returns:
        The flat abstract output values.
    """
    # Abstract stand-in for the rng key: uint32 with shape (2,), which matches a
    # raw jax.random.PRNGKey.
    rng_aval = ShapedArray((2,), 'uint32')
    leaves, treedef = jax.tree_flatten((rng_aval, *inputs))
    flat_fn, out_tree_thunk = jax.flatten_fun_nokwargs(self._init_and_apply, treedef)
    self._cached_out_tree = out_tree_thunk
    # Every input is fully unknown: an aval paired with a unit constant.
    unknown_inputs = [PartialVal((aval, jc.unit)) for aval in leaves]
    _, partial_outputs, _ = trace_to_jaxpr(flat_fn, unknown_inputs, instantiate=True)
    output_avals, _ = unzip2(partial_outputs)
    return output_avals
def _out_tree(self, *inputs): if self._cached_out_tree is not None: result = self._cached_out_tree() self._cached_out_tree = None return result flat_rng_and_inputs, in_tree_with_rng = jax.tree_flatten( (parametrized.dummy_rng, ) + inputs) flat_fun, out_tree = jax.flatten_fun_nokwargs(self._init_and_apply, in_tree_with_rng) # Need to abstract eval in order to build out tree: pe.trace_to_jaxpr(flat_fun, parametrized._partialize(flat_rng_and_inputs), instantiate=True) return out_tree()
def wrapped(*args, **kwargs):
    """Invoke ``f`` (from the enclosing scope) by binding ``jax_core.call_p``.

    ``kwargs`` are handed to ``lu.wrap_init`` together with ``f``. ``args`` are
    flattened into leaves for the primitive.

    Returns:
        ``f``'s results, restructured into their original pytree form.
    """
    wrapped_f = lu.wrap_init(f, kwargs)
    leaves, treedef = jax.tree_flatten(args)
    flat_fn, out_tree_thunk = jax.flatten_fun_nokwargs(wrapped_f, treedef)
    flat_results = jax_core.call_p.bind(flat_fn, *leaves)
    # out_tree_thunk is only populated once bind has run flat_fn, so it must be
    # called after the line above, never in the same expression before it.
    return jax.tree_unflatten(out_tree_thunk(), flat_results)
def call_flattened(*flat_inputs, in_tree, out_tree_container):
    """Evaluate ``self._wrapped_fun`` (``self`` from the enclosing scope) on flat inputs.

    Args:
        *flat_inputs: Leaves of the positional arguments.
        in_tree: Describes how ``flat_inputs`` unflatten into those arguments.
        out_tree_container: Accepted but not read or written here.

    NOTE(review): this function never reports the output tree back, because
    nothing is appended to ``out_tree_container``. Confirm that callers do not
    rely on the container being populated on this path.

    Returns:
        The flat outputs of the wrapped function.
    """
    flat_fn = flatten_fun_nokwargs(self._wrapped_fun, in_tree)[0]
    return flat_fn.call_wrapped(*flat_inputs)