def eval_summary(
    f: Union[Callable[..., Any], hk.Transformed, hk.TransformedWithState],
) -> Callable[..., Sequence[MethodInvocation]]:
    """Traces ``f`` and reports the module methods it invokes.

    >>> f = lambda x: hk.nets.MLP([300, 100, 10])(x)
    >>> x = jnp.ones([8, 28 * 28])
    >>> for i in hk.experimental.eval_summary(f)(x):
    ...   print("mod := {:14} | in := {} out := {}".format(
    ...       i.module_details.module.module_name, i.args_spec[0], i.output_spec))
    mod := mlp            | in := f32[8,784] out := f32[8,10]
    mod := mlp/~/linear_0 | in := f32[8,784] out := f32[8,300]
    mod := mlp/~/linear_1 | in := f32[8,300] out := f32[8,100]
    mod := mlp/~/linear_2 | in := f32[8,100] out := f32[8,10]

    The function is initialised and applied once under ``jax.eval_shape``
    using a fixed ``PRNGKey(42)``, with a method interceptor installed while
    it is applied.

    Args:
      f: A function or transformed function to trace.

    Returns:
      A callable accepting the same arguments as ``f``. Calling it returns a
      sequence of :class:`MethodInvocation` instances describing the methods
      called on each module while ``f`` was applied.

    See Also:
      :func:`tabulate`: Pretty prints a summary of the execution of a function.
    """
    # Every call of the returned wrapper pushes its own result list onto this
    # stack; the traced function looks it up with ``peek()``.
    invocation_stack = data_structures.ThreadLocalStack()

    # If no original function can be retrieved from ``f`` (AttributeError),
    # trace ``f`` exactly as it was given.
    try:
        f = transform.get_original_fn(f)
    except AttributeError:
        pass

    # We know that we will only evaluate this function once and that inside
    # eval_shape we will re-trace any jitted/pmap-ed code. This allows users
    # to pass in jit/pmap decorated apply functions (e.g. train_step).
    f = make_hk_transform_ignore_jax_transforms(f)

    def f_logged(*args, **kwargs):
        # Fetch the list belonging to the wrapper call currently in progress
        # and have the interceptor record into it while ``f`` runs.
        interceptor = functools.partial(
            log_used_modules, invocation_stack.peek())
        with hk.intercept_methods(interceptor):
            f(*args, **kwargs)

    plain_transformed = hk.transform_with_state(f)
    logged_transformed = hk.transform_with_state(f_logged)

    def init_apply(*args, **kwargs):
        seed = jax.random.PRNGKey(42)
        init_rng, apply_rng = jax.random.split(seed)
        # Initialise with the un-intercepted function, then apply the
        # intercepted one with the resulting params and state.
        params, state = plain_transformed.init(init_rng, *args, **kwargs)
        logged_transformed.apply(params, state, apply_rng, *args, **kwargs)

    def wrapper(*args, **kwargs) -> Sequence[MethodInvocation]:
        invocations = []
        with invocation_stack(invocations):
            jax.eval_shape(init_apply, *args, **kwargs)
        return invocations

    return wrapper
def test_stack_per_thread(self):
    """Checks that pushes made in another thread are not seen by this one."""
    stack = data_structures.ThreadLocalStack()
    self.assertEmpty(stack)
    stack.push(42)

    # Written by the worker thread; stays None if the worker fails early.
    worker_depth = None

    def worker():
        nonlocal worker_depth
        # The value pushed by the main thread must not be visible here.
        self.assertEmpty(stack)
        for value in (666, 777):
            stack.push(value)
        worker_depth = len(stack)

    thread = threading.Thread(target=worker)
    thread.start()
    thread.join()

    self.assertEqual(worker_depth, 2)
    # The main thread still holds only its own single element.
    self.assertEqual(stack.pop(), 42)
    self.assertEmpty(stack)
import html
from typing import Any, Callable, NamedTuple, List, Optional
from haiku._src import data_structures
from haiku._src import module
from haiku._src import utils
import jax

# Import tree if available, but only throw error at runtime.
# Permits us to drop dm-tree from deps.
try:
    import tree  # pylint: disable=g-import-not-at-top
except ImportError:
    tree = None

# Thread-local stack; its use is outside the code shown here.
graph_stack = data_structures.ThreadLocalStack()

# Lightweight records for the nodes and edges stored in a ``Graph``.
Node = collections.namedtuple('Node', 'id,title,outputs')
Edge = collections.namedtuple('Edge', 'a,b')


class Graph(NamedTuple):
    """Represents a graphviz digraph/subgraph."""

    # Title of the graph; ``None`` when created without one.
    title: Optional[str]
    # Nodes belonging directly to this graph.
    nodes: List[Node]
    # Edges belonging directly to this graph.
    edges: List[Edge]
    # Nested graphs.
    subgraphs: List['Graph']

    @classmethod
    def create(cls, title: Optional[str] = None) -> 'Graph':
        """Builds an empty graph with fresh node, edge and subgraph lists.

        Args:
          title: Optional title for the graph.

        Returns:
          A new instance of ``cls`` whose ``nodes``, ``edges`` and
          ``subgraphs`` are new empty lists not shared with any other graph.
        """
        # Construct through ``cls`` so subclasses get an instance of their
        # own type rather than a plain ``Graph``.
        return cls(title=title, nodes=[], edges=[], subgraphs=[])