def test_persists_original_fn(self, without):
  """``get_original_fn`` should recover the user function from a transform.

  ``without`` is either ``None`` or a callable that is applied to the
  transformed function before checking. The recovered function must be
  identical (``is``) to the user function for the transformed object itself
  and for its ``init`` and ``apply`` members, checked in that order.
  """
  user_fn = lambda: None
  base = transform.transform(user_fn)
  transformed = base if without is None else without(base)

  self.assertIs(transform.get_original_fn(transformed), user_fn)
  for member in ("init", "apply"):
    self.assertIs(
        transform.get_original_fn(getattr(transformed, member)), user_fn)
def eval_summary(
    f: Union[Callable[..., Any], hk.Transformed, hk.TransformedWithState],
) -> Callable[..., Sequence[MethodInvocation]]:
  """Records module method calls performed by ``f``.

  >>> f = lambda x: hk.nets.MLP([300, 100, 10])(x)
  >>> x = jnp.ones([8, 28 * 28])
  >>> for i in hk.experimental.eval_summary(f)(x):
  ...   print("mod := {:14} | in := {} out := {}".format(
  ...       i.module_details.module.module_name, i.args_spec[0], i.output_spec))
  mod := mlp            | in := f32[8,784] out := f32[8,10]
  mod := mlp/~/linear_0 | in := f32[8,784] out := f32[8,300]
  mod := mlp/~/linear_1 | in := f32[8,300] out := f32[8,100]
  mod := mlp/~/linear_2 | in := f32[8,100] out := f32[8,10]

  ``f`` is evaluated with :func:`jax.eval_shape`, so only shapes and dtypes
  are traced. The return value of ``f`` is discarded; only the recorded
  method calls are returned.

  Args:
    f: A function or transformed function to trace. If ``f`` is a transformed
      function, it is unwrapped to the original user function via
      ``transform.get_original_fn``. Otherwise it is used as given.

  Returns:
    A callable taking the same arguments as the provided function, but
    returning a sequence of :class:`MethodInvocation` instances revealing the
    methods called on each module when applying ``f``. Each call of the
    returned callable produces a new list.

  See Also:
    :func:`tabulate`: Pretty prints a summary of the execution of a function.
  """
  # Side channel that carries the per-call result list from ``wrapper`` into
  # ``f_logged``. ``f_logged`` runs under ``jax.eval_shape`` tracing, so the
  # list is handed over here rather than returned through the traced function.
  # NOTE(review): assumes ``peek()`` returns the value most recently pushed by
  # ``with sidechannel(...)``; confirm against data_structures.
  sidechannel = data_structures.ThreadLocalStack()

  # Accept already-transformed functions by unwrapping them to the user
  # function. When ``get_original_fn`` raises AttributeError, ``f`` is used
  # unchanged.
  try:
    f = transform.get_original_fn(f)
  except AttributeError:
    pass

  def f_logged(*args, **kwargs):
    # The list that ``wrapper`` pushed for the current call.
    used_modules = sidechannel.peek()
    logging_interceptor = functools.partial(log_used_modules, used_modules)

    # Run the user function with module method calls routed through the
    # interceptor, which records them into ``used_modules``.
    # NOTE: ``f`` is a closure variable resolved at call time. By the time this
    # runs, ``f`` has been rebound below to
    # ``make_hk_transform_ignore_jax_transforms(f)``, and that wrapped version
    # is what gets called here.
    with hk.intercept_methods(logging_interceptor):
      f(*args, **kwargs)

  # We know that we will only evaluate this function once and that inside
  # eval_shape we will re-trace any jitted/pmap-ed code. This allows users to
  # pass in jit/pmap decorated apply functions (e.g. train_step).
  f = make_hk_transform_ignore_jax_transforms(f)

  # ``f_orig`` (no interceptor) is used only for init, and ``f_logged`` is used
  # only for apply. As a result, module calls made while creating
  # params/state are not recorded.
  f_orig = hk.transform_with_state(f)
  f_logged = hk.transform_with_state(f_logged)

  def init_apply(*args, **kwargs):
    # Fixed seed: this runs under jax.eval_shape, where no concrete values are
    # computed, so any valid key will do.
    init_rng, apply_rng = jax.random.split(jax.random.PRNGKey(42))
    params, state = f_orig.init(init_rng, *args, **kwargs)
    f_logged.apply(params, state, apply_rng, *args, **kwargs)

  def wrapper(*args, **kwargs) -> Sequence[MethodInvocation]:
    # Fresh result list per call, made visible to ``f_logged`` through the
    # side channel for the duration of the trace.
    used_modules = []
    with sidechannel(used_modules):
      # Abstract evaluation only: traces init + apply using shapes/dtypes.
      jax.eval_shape(init_apply, *args, **kwargs)
    return used_modules

  return wrapper
def assertPersistsOriginal(self, f, orig_f):
  """Asserts that ``f`` and its ``init``/``apply`` all unwrap to ``orig_f``.

  Uses identity (``assertIs``) rather than equality. ``f`` is checked first,
  then ``f.init``, then ``f.apply``.
  """
  self.assertIs(transform.get_original_fn(f), orig_f)
  for member in ("init", "apply"):
    unwrapped = transform.get_original_fn(getattr(f, member))
    self.assertIs(unwrapped, orig_f)