Ejemplo n.º 1
0
 def test_persists_original_fn(self, without):
     """Checks that a transformed fn and its init/apply unwrap to the original."""
     original = lambda: None
     transformed = transform.transform(original)
     if without is not None:
         # Optionally post-process the transformed fn with the `without`
         # callable (presumably supplied by the test parameterization).
         transformed = without(transformed)
     # The transformed object itself and both of its entry points must all
     # resolve back to the very same original function object.
     for candidate in (transformed, transformed.init, transformed.apply):
         self.assertIs(transform.get_original_fn(candidate), original)
Ejemplo n.º 2
0
def eval_summary(
    f: Union[Callable[..., Any], hk.Transformed, hk.TransformedWithState],
) -> Callable[..., Sequence[MethodInvocation]]:
    """Records module method calls performed by ``f``.

    >>> f = lambda x: hk.nets.MLP([300, 100, 10])(x)
    >>> x = jnp.ones([8, 28 * 28])
    >>> for i in hk.experimental.eval_summary(f)(x):
    ...   print("mod := {:14} | in := {} out := {}".format(
    ...       i.module_details.module.module_name, i.args_spec[0], i.output_spec))
    mod := mlp            | in := f32[8,784] out := f32[8,10]
    mod := mlp/~/linear_0 | in := f32[8,784] out := f32[8,300]
    mod := mlp/~/linear_1 | in := f32[8,300] out := f32[8,100]
    mod := mlp/~/linear_2 | in := f32[8,100] out := f32[8,10]

    Args:
      f: A function or transformed function to trace.

    Returns:
      A callable taking the same arguments as the provided function, but
      returning a sequence of :class:`MethodInvocation` instances revealing
      the methods called on each module when applying ``f``.

    See Also:
      :func:`tabulate`: Pretty prints a summary of the execution of a function.
    """
    # Stack used to hand each ``wrapper`` call's ``used_modules`` list to
    # ``logged_fn`` without threading it through the transformed function's
    # arguments (``wrapper`` enters ``sidechannel(used_modules)``;
    # ``logged_fn`` reads it back with ``peek()``).
    sidechannel = data_structures.ThreadLocalStack()

    # Unwrap Haiku-transformed inputs back to the user's original function so
    # it can be re-transformed below. Inputs that ``get_original_fn`` cannot
    # unwrap (it raises AttributeError) are used as-is.
    try:
        f = transform.get_original_fn(f)
    except AttributeError:
        pass

    # We know that we will only evaluate this function once and that inside
    # eval_shape we will re-trace any jitted/pmap-ed code. This allows users to
    # pass in jit/pmap decorated apply functions (e.g. train_step).
    # The wrapped function gets its own name (instead of rebinding ``f``) so
    # the closure in ``logged_fn`` captures it explicitly rather than relying
    # on late binding to pick up a reassignment made after the closure.
    traced_f = make_hk_transform_ignore_jax_transforms(f)

    def logged_fn(*args, **kwargs):
        # Route intercepted module method calls to ``log_used_modules`` along
        # with the list pushed by the enclosing ``wrapper`` call (presumably it
        # records each invocation into that list).
        used_modules = sidechannel.peek()
        logging_interceptor = functools.partial(log_used_modules, used_modules)
        with hk.intercept_methods(logging_interceptor):
            traced_f(*args, **kwargs)

    f_orig = hk.transform_with_state(traced_f)
    f_logged = hk.transform_with_state(logged_fn)

    def init_apply(*args, **kwargs):
        # Initialize with the un-instrumented function and apply the
        # instrumented one, so only apply-time calls go through the
        # interceptor. Separate keys keep init and apply randomness independent.
        init_rng, apply_rng = jax.random.split(jax.random.PRNGKey(42))
        params, state = f_orig.init(init_rng, *args, **kwargs)
        f_logged.apply(params, state, apply_rng, *args, **kwargs)

    def wrapper(*args, **kwargs) -> Sequence[MethodInvocation]:
        used_modules = []
        with sidechannel(used_modules):
            # eval_shape traces init_apply abstractly (shapes/dtypes only), so
            # no real computation is performed.
            jax.eval_shape(init_apply, *args, **kwargs)
        return used_modules

    return wrapper
Ejemplo n.º 3
0
 def assertPersistsOriginal(self, f, orig_f):
   """Asserts ``f``, ``f.init`` and ``f.apply`` all unwrap to ``orig_f``."""
   for wrapped in (f, f.init, f.apply):
     self.assertIs(transform.get_original_fn(wrapped), orig_f)