def abstract_call(*inputs):
    """Trace ``self._init_and_apply`` on ``inputs`` and return the flat outputs.

    A ``ShapedArray((2, ), 'uint32')`` is prepended to ``inputs`` before
    flattening (presumably a stand-in for the RNG key -- confirm against
    ``_init_and_apply``'s signature).

    Side effect: the second value returned by ``jax.flatten_fun_nokwargs``
    is stored on ``self._cached_out_tree``.

    Returns:
        The first components of the output partial values produced by
        ``trace_to_jaxpr(..., instantiate=True)``.
    """
    rng_placeholder = ShapedArray((2, ), 'uint32')
    args_with_rng = (rng_placeholder, ) + inputs
    leaves, tree_with_rng = jax.tree_flatten(args_with_rng)

    flattened = jax.flatten_fun_nokwargs(self._init_and_apply, tree_with_rng)
    flat_fun, self._cached_out_tree = flattened

    # Every flattened input is paired with ``jc.unit`` in its PartialVal.
    partial_inputs = [PartialVal((leaf, jc.unit)) for leaf in leaves]

    # Only the output partial values are needed; the other two results of
    # the trace are discarded.
    _, partial_outs, _ = trace_to_jaxpr(
        flat_fun, partial_inputs, instantiate=True)
    outs, _ = unzip2(partial_outs)
    return outs
def merge(tracers):
    """Split ``tracers`` into their values and one shared params iterator.

    Each tracer contributes its ``val`` and its ``submodule_params_iter``.
    Among the latter, every ``ApplyTrace.SubmoduleParamsIterator`` must be
    the very same object; anything else must be an empty ``dict``. Both
    conditions are checked with ``assert``.

    Returns:
        ``(flat_inputs, submodule_param_iter)`` where the second item is the
        shared iterator, or ``None`` if no tracer carried one.
    """
    flat_inputs, per_tracer_iters = unzip2(
        (tracer.val, tracer.submodule_params_iter) for tracer in tracers)

    shared_iter = None
    for candidate in per_tracer_iters:
        if not isinstance(candidate, ApplyTrace.SubmoduleParamsIterator):
            # Tracers without a real iterator carry an empty dict instead.
            assert isinstance(candidate, dict)
            assert len(candidate) == 0
            continue

        # All real iterators must be one and the same object.
        assert shared_iter is None or candidate is shared_iter
        shared_iter = candidate

    return flat_inputs, shared_iter
def _instantiated_trace_to_jaxpr(fun, avals):
    """Trace ``fun`` with ``trace_to_jaxpr`` using ``instantiate=True``.

    Args:
        fun: Passed through to ``trace_to_jaxpr`` as the function to trace.
        avals: Iterable of input values; each one is wrapped as
            ``PartialVal((aval, unit))`` before tracing.

    Returns:
        A tuple ``(jaxpr, out_avals, consts)``: the jaxpr and constants
        returned by ``trace_to_jaxpr``, and the first components of the
        output partial values.
    """
    # Build a concrete list rather than a ``map`` object so the inputs can be
    # iterated more than once and are constructed up front.
    pvals = [PartialVal((aval, unit)) for aval in avals]
    jaxpr, out_pvals, consts = trace_to_jaxpr(fun, pvals, instantiate=True)
    out_avals, _ = unzip2(out_pvals)
    return jaxpr, out_avals, consts