def test_method_hook(self):
  """Checks that `module.hook_methods` wraps both `__init__` and `__call__`.

  The hook logs an ("enter", ...) entry before the wrapped method runs and an
  ("exit", ...) entry after it returns. The call into a captured submodule
  must therefore appear nested inside the outer module's entries.
  """
  log = []

  @contextlib.contextmanager
  def recording_hook(mod, method_name):
    # On entry to `__init__` the module may not have a `module_name` yet, so
    # fall back to None instead of failing.
    log.append(("enter", method_name, getattr(mod, "module_name", None)))
    yield
    log.append(("exit", method_name, mod.module_name))

  # Construction is hooked. The name is only known once `__init__` finishes.
  with module.hook_methods(recording_hook):
    subject = EmptyModule()
  self.assertIsNotNone(subject)
  expected_init = [
      ("enter", "__init__", None),
      ("exit", "__init__", "empty_module"),
  ]
  self.assertEqual(log, expected_init)

  # Calling is hooked, including the nested call into the submodule.
  log.clear()
  subject = CapturesModule(ScalarModule())
  with module.hook_methods(recording_hook):
    subject()
  expected_call = [
      ("enter", "__call__", "captures_module"),
      ("enter", "__call__", "scalar_module"),
      ("exit", "__call__", "scalar_module"),
      ("exit", "__call__", "captures_module"),
  ]
  self.assertEqual(log, expected_call)
def wrapped_fun(*args):
  """Traces `fun` on the given positional arguments.

  Returns:
    A `(graph, args, out)` tuple:
      * `graph` is the root `Graph`, titled after `fun`.
      * `args` is the positional-argument tuple exactly as received.
      * `out` is the output of `fun`, rebuilt from its flattened form.
  """
  flat_args, in_tree = jax.tree_flatten((args, {}))
  flat_fun, out_tree = jax.api_util.flatten_fun(
      jax.linear_util.wrap_init(fun), in_tree)
  root = Graph.create(title=name_or_str(fun))

  @contextlib.contextmanager
  def subgraph_hook(mod: module.Module, method_name: str):
    # While the hooked method runs, a fresh graph sits on top of
    # `graph_stack`. Presumably this is where the tracer records nodes.
    # Confirm against DotTrace.
    child = Graph.create()
    with graph_stack(child):
      yield
    # Once the child is popped, attach it to the enclosing graph (the new
    # top of the stack). Methods other than `__call__` get the method name
    # appended in parentheses.
    base = mod.module_name
    label = base if method_name == '__call__' else base + f' ({method_name})'
    graph_stack.peek().subgraphs.append(child.evolve(title=label))

  with graph_stack(root):
    with module.hook_methods(subgraph_hook):
      with jax.core.new_main(DotTrace) as main_trace:
        traced = _interpret_subtrace(flat_fun, main_trace)
        out_flat = traced.call_wrapped(*flat_args)
  out = jax.tree_unflatten(out_tree(), out_flat)
  return root, args, out
def test_callback_runs_after_submodules_updated(self):
  """Checks that code after the hook's `yield` sees submodule parameters.

  The outer module's post-call snapshot must already include the parameter
  created by the inner module during the call.
  """
  seen = []

  @contextlib.contextmanager
  def snapshot_hook(mod, method_name):
    yield
    # The hook calls `mod.params_dict()` itself. If that method is also
    # hooked, skipping it presumably keeps those inner calls out of the log.
    if method_name == "params_dict":
      return
    seen.append((mod.module_name, method_name, tuple(mod.params_dict())))

  outer = CapturesModule(ScalarModule())
  with module.hook_methods(snapshot_hook):
    outer()
  expected = [
      ("scalar_module", "__call__", ("scalar_module/w",)),
      ("captures_module", "__call__", ("scalar_module/w",)),
  ]
  self.assertEqual(seen, expected)