コード例 #1
0
ファイル: core.py プロジェクト: juliuskunze/jaxnet
 def _init_and_apply_parameters_dict(self, *example_inputs, key):
     """Run the wrapped function once under the init transform.

     Flattens ``example_inputs``, feeds the leaves through ``self._wrapped_fun``
     wrapped by ``_init_transform(..., key)``, and rebuilds the output pytree.

     Returns:
         A pair ``(parameters, outputs)``: ``parameters`` is whatever the thunk
         returned by ``_init_transform`` yields; ``outputs`` is the unflattened
         result of the call.
     """
     leaves, in_treedef = tree_flatten(example_inputs)
     flat_wrapped, out_tree_thunk = flatten_fun_nokwargs(self._wrapped_fun, in_treedef)
     init_fun, parameters_thunk = _init_transform(flat_wrapped, key)
     flat_results = init_fun.call_wrapped(*leaves)
     # The output-tree thunk is read only after the wrapped call has run.
     outputs = tree_unflatten(out_tree_thunk(), flat_results)
     return parameters_thunk(), outputs
コード例 #2
0
ファイル: core.py プロジェクト: juliuskunze/jaxnet
 def abstract_eval(self, *avals, **kwargs):
     """Abstractly evaluate ``self._wrapped_example_outputs_fun`` on ``avals``.

     ``kwargs`` is expected to carry ``in_tree`` (the input tree structure) and
     ``out_tree_container`` (an object with ``append``, presumably a list). The
     output tree structure found while tracing is appended to that container,
     which acts as a side channel back to the caller.

     Returns:
         The flat outputs reported by ``_instantiated_trace_to_jaxpr``.
     """
     in_tree, out_tree_container = split_dict(kwargs, ['in_tree', 'out_tree_container'])
     flat_fun, get_out_tree = flatten_fun_nokwargs(self._wrapped_example_outputs_fun, in_tree)
     # Tracing has to happen first: it is what fills in get_out_tree.
     _, flat_results, _ = _instantiated_trace_to_jaxpr(flat_fun, avals)
     # Hand the output tree back to the caller through the container.
     out_tree_container.append(get_out_tree())
     return flat_results
コード例 #3
0
ファイル: core.py プロジェクト: juliuskunze/jaxnet
 def _apply(self, parameters, *inputs, key):
     """Apply the wrapped function to ``inputs`` using ``parameters``.

     Runs ``self._wrapped_fun`` under ``_apply_transform`` inside a fresh
     ``ApplyTrace`` main. If ``_top_trace`` finds an enclosing ``ApplyTrace``,
     its ``random_state`` and ``global_parameters_by_primitive`` are reused.
     Otherwise a new ``RandomState(key)`` and an empty dict are used.

     Returns:
         The outputs, unflattened into the tree structure the wrapped
         function produced.
     """
     flat_inputs, in_tree = tree_flatten(inputs)
     flat_fun, out_tree = flatten_fun_nokwargs(self._wrapped_fun, in_tree)
     # Enclosing apply trace, or a falsy value when there is none; a nested
     # application reuses its parent's state instead of starting fresh.
     apply_trace = _top_trace(filter_type=ApplyTrace)
     with new_main(ApplyTrace) as master:
         global_parameters_by_primitive = apply_trace.state.global_parameters_by_primitive \
             if apply_trace else {}
         random_state = apply_trace.state.random_state if apply_trace else RandomState(key)
         master.state = ApplyTraceState(random_state, parameters, global_parameters_by_primitive)
         flat_outputs = _apply_transform(flat_fun, master).call_wrapped(*flat_inputs)
         # NOTE(review): presumably drops the local reference so new_main's
         # exit-time leak check sees the main trace released — confirm
         # against the JAX version in use.
         del master
     # out_tree is a thunk; it is read only after the call above has run.
     return tree_unflatten(out_tree(), flat_outputs)
コード例 #4
0
 def abstract_call(*inputs):
     """Trace ``self._init_and_apply`` abstractly on ``inputs``.

     A ``ShapedArray((2, ), 'uint32')`` placeholder for the key is prepended
     to the inputs. As a side effect, the output-tree thunk is stored on
     ``self._cached_out_tree`` (``self`` comes from the enclosing scope).

     Returns:
         The first component of each partial output after tracing.
     """
     abstract_key = ShapedArray((2, ), 'uint32')
     leaves, treedef = jax.tree_flatten((abstract_key, ) + inputs)
     traced_fun, self._cached_out_tree = jax.flatten_fun_nokwargs(
         self._init_and_apply, treedef)
     # Each flattened leaf is wrapped as PartialVal((leaf, jc.unit)).
     partial_leaves = [PartialVal((leaf, jc.unit)) for leaf in leaves]
     _, partial_outs, _ = trace_to_jaxpr(traced_fun, partial_leaves, instantiate=True)
     first_parts, _ = unzip2(partial_outs)
     return first_parts
コード例 #5
0
    def _out_tree(self, *inputs):
        """Return the output tree structure of ``self._init_and_apply`` for ``inputs``.

        If ``self._cached_out_tree`` holds a thunk, that thunk is called, the
        cache is cleared, and its result is returned. Otherwise the function is
        traced with ``parametrized.dummy_rng`` prepended to ``inputs``, solely
        to discover the output tree.
        """
        cached_thunk = self._cached_out_tree
        if cached_thunk is None:
            leaves, treedef = jax.tree_flatten((parametrized.dummy_rng, ) + inputs)
            traced_fun, get_out_tree = jax.flatten_fun_nokwargs(self._init_and_apply,
                                                                treedef)
            # The trace result is discarded; tracing is needed only because it
            # populates get_out_tree.
            pe.trace_to_jaxpr(traced_fun,
                              parametrized._partialize(leaves),
                              instantiate=True)
            return get_out_tree()

        # One-shot cache: call the thunk first, then clear it.
        tree = cached_thunk()
        self._cached_out_tree = None
        return tree
コード例 #6
0
 def wrapped(*args, **kwargs):
     """Call ``f`` on ``args``/``kwargs`` by binding ``jax_core.call_p``.

     Keyword arguments are attached to ``f`` via ``lu.wrap_init``. Only the
     flattened positional leaves are passed to ``call_p.bind``, and the flat
     result is rebuilt into ``f``'s output tree structure.
     """
     init_fun = lu.wrap_init(f, kwargs)
     leaves, treedef = jax.tree_flatten(args)
     flat_fun, get_out_tree = jax.flatten_fun_nokwargs(init_fun, treedef)
     flat_results = jax_core.call_p.bind(flat_fun, *leaves)
     return jax.tree_unflatten(get_out_tree(), flat_results)
コード例 #7
0
ファイル: core.py プロジェクト: juliuskunze/jaxnet
 def call_flattened(*flat_inputs, in_tree, out_tree_container):
     """Call ``self._wrapped_fun`` on already-flattened inputs.

     ``in_tree`` describes how ``flat_inputs`` unflatten into the original
     arguments. ``out_tree_container`` is accepted but not used here.

     Returns:
         The flat outputs of the wrapped function.
     """
     flattened_fun, _unused_out_tree = flatten_fun_nokwargs(self._wrapped_fun, in_tree)
     return flattened_fun.call_wrapped(*flat_inputs)