# Example no. 1
def eval_summary(
    f: Union[Callable[..., Any], hk.Transformed, hk.TransformedWithState],
) -> Callable[..., Sequence[MethodInvocation]]:
    """Records the module method calls made while evaluating ``f``.

    >>> f = lambda x: hk.nets.MLP([300, 100, 10])(x)
    >>> x = jnp.ones([8, 28 * 28])
    >>> for i in hk.experimental.eval_summary(f)(x):
    ...   print("mod := {:14} | in := {} out := {}".format(
    ...       i.module_details.module.module_name, i.args_spec[0], i.output_spec))
    mod := mlp            | in := f32[8,784] out := f32[8,10]
    mod := mlp/~/linear_0 | in := f32[8,784] out := f32[8,300]
    mod := mlp/~/linear_1 | in := f32[8,300] out := f32[8,100]
    mod := mlp/~/linear_2 | in := f32[8,100] out := f32[8,10]

    Args:
      f: A plain function, or a (stateful) transformed function, to trace.

    Returns:
      A callable that accepts the same arguments as ``f`` and, instead of
      computing ``f``'s result, returns a sequence of
      :class:`MethodInvocation` instances describing which methods were
      invoked on which modules while ``f`` ran.

    See Also:
      :func:`tabulate`: Pretty prints a summary of the execution of a function.
    """
    # Each call to ``wrapper`` pushes its own result list onto this stack, and
    # the logging function reads the top entry to know where to record calls.
    invocation_stack = data_structures.ThreadLocalStack()

    # If ``f`` is already a transformed function, trace the user function it
    # was built from; otherwise use ``f`` unchanged.
    try:
        f = transform.get_original_fn(f)
    except AttributeError:
        pass

    # ``f`` is evaluated exactly once, under ``jax.eval_shape``, which re-traces
    # any jit/pmap-ed code anyway. Ignoring those transforms therefore lets
    # callers pass jit/pmap-decorated apply functions (e.g. a train_step).
    traced_fn = make_hk_transform_ignore_jax_transforms(f)

    def recording_fn(*args, **kwargs):
        # Route every intercepted module method call into the result list
        # owned by the innermost active ``wrapper`` call.
        sink = invocation_stack.peek()
        with hk.intercept_methods(functools.partial(log_used_modules, sink)):
            traced_fn(*args, **kwargs)

    plain = hk.transform_with_state(traced_fn)
    recording = hk.transform_with_state(recording_fn)

    def init_apply(*args, **kwargs):
        # Initialise with the un-instrumented function, then apply with the
        # instrumented one so that only the apply pass is recorded.
        rng_init, rng_apply = jax.random.split(jax.random.PRNGKey(42))
        params, state = plain.init(rng_init, *args, **kwargs)
        recording.apply(params, state, rng_apply, *args, **kwargs)

    def wrapper(*args, **kwargs) -> Sequence[MethodInvocation]:
        collected = []
        with invocation_stack(collected):
            # Only shapes/dtypes are computed; no real FLOPs are spent.
            jax.eval_shape(init_apply, *args, **kwargs)
        return collected

    return wrapper
# Example no. 2
    def test_stack_per_thread(self):
        """Entries pushed on one thread are not visible from another thread."""
        s = data_structures.ThreadLocalStack()
        self.assertEmpty(s)
        s.push(42)

        # An assertion that fails inside a worker thread does not propagate to
        # the test runner; the exception only reaches threading.excepthook.
        # The worker therefore records what it observed, and all assertions
        # run on the main thread where a failure is reported accurately.
        observed = {}

        def second_thread():
            observed['initial_len'] = len(s)
            s.push(666)
            s.push(777)
            observed['final_len'] = len(s)

        t = threading.Thread(target=second_thread)
        t.start()
        t.join()

        # The second thread must start with its own empty stack...
        self.assertEqual(observed.get('initial_len'), 0)
        self.assertEqual(observed.get('final_len'), 2)
        # ...and its pushes must not leak into this thread's stack.
        self.assertEqual(s.pop(), 42)
        self.assertEmpty(s)
# Example no. 3
import collections
import html
from typing import Any, Callable, NamedTuple, List, Optional

from haiku._src import data_structures
from haiku._src import module
from haiku._src import utils
import jax

# Import dm-tree if it is available. If it is not, ``tree`` is set to ``None``
# so that any error is deferred to runtime rather than import time; this lets
# dm-tree be dropped from the required dependencies.
try:
    import tree  # pylint: disable=g-import-not-at-top
except ImportError:
    tree = None

# Thread-local stack; presumably holds the Graph(s) currently being built
# (NOTE(review): confirm against its push/peek call sites).
graph_stack = data_structures.ThreadLocalStack()
# Record types with fields (id, title, outputs) and (a, b) respectively;
# Edge presumably connects node ``a`` to node ``b`` -- TODO confirm.
Node = collections.namedtuple('Node', 'id,title,outputs')
Edge = collections.namedtuple('Edge', 'a,b')


class Graph(NamedTuple):
    """Represents a graphviz digraph/subgraph.."""

    title: str
    nodes: List[Node]
    edges: List[Edge]
    subgraphs: List['Graph']

    @classmethod
    def create(cls, title: Optional[str] = None) -> 'Graph':
        """Returns an empty graph with the given ``title``.

        Args:
          title: Optional title for the (sub)graph; ``None`` when omitted.

        Returns:
          A new instance of ``cls`` whose ``nodes``, ``edges`` and ``subgraphs``
          are fresh empty lists, so they are never shared between graphs.
        """
        # ``cls`` rather than a hard-coded ``Graph`` so subclasses get
        # instances of their own type.
        return cls(title=title, nodes=[], edges=[], subgraphs=[])