Exemplo n.º 1
0
 def abstract_call(*inputs):
     """Trace ``self._init_and_apply`` abstractly and return its flat outputs.

     A ``ShapedArray((2, ), 'uint32')`` is prepended to ``inputs`` (presumably
     the abstract RNG key expected by ``_init_and_apply`` -- TODO confirm).
     The combined arguments are flattened into leaves. Each leaf is wrapped
     as ``PartialVal((leaf, jc.unit))``, and the flattened function is traced
     with ``instantiate=True``.

     Side effect: the second value returned by ``jax.flatten_fun_nokwargs``
     is stored on ``self._cached_out_tree`` (presumably so the caller can
     rebuild the output structure later -- confirm against callers).

     Args:
       *inputs: positional arguments appended after the prepended aval;
         presumably abstract values or pytrees of them (TODO confirm).

     Returns:
       The first component of every flat output partial value produced by
       ``trace_to_jaxpr``.
     """
     rng_aval = ShapedArray((2, ), 'uint32')
     leaves, tree_def = jax.tree_flatten((rng_aval, ) + inputs)
     flat_fun, self._cached_out_tree = jax.flatten_fun_nokwargs(
         self._init_and_apply, tree_def)

     partial_vals = []
     for leaf in leaves:
         partial_vals.append(PartialVal((leaf, jc.unit)))

     # The jaxpr and constants are not needed here; only output avals are.
     _, out_partial_vals, _ = trace_to_jaxpr(
         flat_fun, partial_vals, instantiate=True)
     out_avals, _ = unzip2(out_partial_vals)
     return out_avals
Exemplo n.º 2
0
        def merge(tracers):
            """Split ``tracers`` into their values and one shared params iterator.

            Every tracer contributes ``t.val`` to the returned inputs. Its
            ``t.submodule_params_iter`` must be either an
            ``ApplyTrace.SubmoduleParamsIterator`` or an empty ``dict``. All
            iterator-typed entries must be the very same object (identity check).

            Args:
              tracers: iterable of objects exposing ``val`` and
                ``submodule_params_iter`` attributes.

            Returns:
              ``(flat_inputs, submodule_param_iter)``. ``flat_inputs`` holds the
              tracer values as produced by ``unzip2``. ``submodule_param_iter``
              is the shared iterator, or ``None`` when no tracer carried one.

            Raises:
              AssertionError: if two different iterator objects are found, or
                a non-iterator entry is not an empty ``dict``.
            """
            flat_inputs, submodule_params_iters = unzip2(
                (t.val, t.submodule_params_iter) for t in tracers)

            submodule_param_iter = None
            # ``params_iter`` rather than ``iter`` so the builtin is not shadowed.
            for params_iter in submodule_params_iters:
                if isinstance(params_iter, ApplyTrace.SubmoduleParamsIterator):
                    # Every tracer that carries an iterator must carry the same one.
                    assert (submodule_param_iter is None
                            or params_iter is submodule_param_iter), (
                                'tracers carry different submodule params iterators')
                    submodule_param_iter = params_iter
                else:
                    # Tracers without an iterator must carry an empty dict.
                    assert isinstance(params_iter, dict), type(params_iter)
                    assert len(params_iter) == 0, params_iter

            return flat_inputs, submodule_param_iter
Exemplo n.º 3
0
def _instantiated_trace_to_jaxpr(fun, avals):
    """Trace ``fun`` on abstract inputs with every output instantiated.

    Each aval is wrapped as ``PartialVal((aval, unit))``, and ``fun`` is traced
    via ``trace_to_jaxpr(..., instantiate=True)``.

    Args:
      fun: the function to trace, in whatever form ``trace_to_jaxpr`` accepts
        (presumably a linear_util-wrapped function -- TODO confirm).
      avals: iterable of abstract values, one per flat input.

    Returns:
      ``(jaxpr, out_avals, consts)``. ``jaxpr`` and ``consts`` come straight
      from ``trace_to_jaxpr``. ``out_avals`` is the first component of each
      output partial value.
    """
    in_pvals = [PartialVal((aval, unit)) for aval in avals]
    traced = trace_to_jaxpr(fun, in_pvals, instantiate=True)
    jaxpr, out_pvals, consts = traced
    out_avals = unzip2(out_pvals)[0]
    return jaxpr, out_avals, consts