Exemple #1
0
    def test_method_hook(self):
        """Checks that method hooks wrap `__init__` and `__call__` in call order.

        A recording hook is installed via `module.hook_methods`; it logs an
        "enter" tuple before the wrapped method runs and an "exit" tuple after.
        """
        recorded = []

        @contextlib.contextmanager
        def recording_hook(mod, method_name):
            # The attribute is looked up defensively on entry; the expected
            # events below show it is `None` when `__init__` is entered.
            name_on_entry = getattr(mod, "module_name", None)
            recorded.append(("enter", method_name, name_on_entry))
            yield
            recorded.append(("exit", method_name, mod.module_name))

        # Construction while the hook is active.
        with module.hook_methods(recording_hook):
            m = EmptyModule()
            self.assertIsNotNone(m)
            expected_init_events = [
                ("enter", "__init__", None),
                ("exit", "__init__", "empty_module"),
            ]
            self.assertEqual(recorded, expected_init_events)

        # Calling while the hook is active; the modules themselves are built
        # outside the hook context so only `__call__` events are recorded.
        recorded.clear()
        m = CapturesModule(ScalarModule())
        with module.hook_methods(recording_hook):
            m()
        expected_call_events = [
            ("enter", "__call__", "captures_module"),
            ("enter", "__call__", "scalar_module"),
            ("exit", "__call__", "scalar_module"),
            ("exit", "__call__", "captures_module"),
        ]
        self.assertEqual(recorded, expected_call_events)
Exemple #2
0
    def wrapped_fun(*args):
        """Runs `fun` on `args` under a `DotTrace` and collects a `Graph`.

        See `fun` for the meaning of the arguments.

        Returns:
            A `(graph, args, out)` tuple: the graph built while running `fun`,
            the positional arguments as given, and the unflattened output.
        """
        wrapped = jax.linear_util.wrap_init(fun)
        flat_args, args_treedef = jax.tree_flatten((args, {}))
        flat_fun, out_tree = jax.api_util.flatten_fun(wrapped, args_treedef)
        root_graph = Graph.create(title=name_or_str(fun))

        @contextlib.contextmanager
        def subgraph_hook(mod: module.Module, method_name: str):
            # Each hooked method call gets its own subgraph, which is the
            # active entry on `graph_stack` while the method body runs.
            child = Graph.create()
            with graph_stack(child):
                yield
            # The title is read only after the method has returned.
            if method_name == '__call__':
                label = mod.module_name
            else:
                label = mod.module_name + f' ({method_name})'
            # `child` has left the stack by now, so `peek()` yields the
            # enclosing graph, to which the titled subgraph is attached.
            graph_stack.peek().subgraphs.append(child.evolve(title=label))

        with graph_stack(root_graph):
            with module.hook_methods(subgraph_hook):
                with jax.core.new_main(DotTrace) as master:
                    traced = _interpret_subtrace(flat_fun, master)
                    out_flat = traced.call_wrapped(*flat_args)
        out = jax.tree_unflatten(out_tree(), out_flat)

        return root_graph, args, out
Exemple #3
0
  def test_callback_runs_after_submodules_updated(self):
    """Checks what `params_dict()` reports once each hooked call has returned.

    The hook records, after every wrapped method finishes, the module name,
    the method name and the keys of the module's `params_dict()`.
    """
    observed = []

    @contextlib.contextmanager
    def after_call_hook(mod, method_name):
      yield
      # The hook calls `params_dict()` itself, so that method is skipped here
      # (presumably to avoid re-entering the hook -- confirm).
      if method_name == "params_dict":
        return
      param_names = tuple(mod.params_dict())
      observed.append((mod.module_name, method_name, param_names))

    m = CapturesModule(ScalarModule())
    with module.hook_methods(after_call_hook):
      m()

    # The inner module's call exits first; both entries list the same key.
    expected = [
        ("scalar_module", "__call__", ("scalar_module/w",)),
        ("captures_module", "__call__", ("scalar_module/w",)),
    ]
    self.assertEqual(observed, expected)