Example #1
0
    def wrapped(self, *args, **kwargs):
        """Calls the original method with a group name set before and after.

        The original method is bound to ``self`` and routed through
        ``run_interceptors``. The call runs inside ``frame.module(state)`` and
        ``_module_method_call``. When ``modules_with_named_call`` is enabled and
        the method is ``__call__``, the call is also wrapped in
        ``stateful.named_call``, using the last ``/``-separated component of the
        module name. Afterwards the module name is added to the ``_submodules``
        set of every other module on the frame's module stack.

        Args:
          self: The module instance the method is bound to.
          *args: Positional arguments forwarded to the original method.
          **kwargs: Keyword arguments forwarded to the original method.

        Returns:
          Whatever the (intercepted) original method returns.

        Raises:
          ValueError: If there is no active frame, i.e. the module is used
            outside of an ``hk.transform``.
        """
        if not base.frame_stack:
            raise ValueError(
                "All `hk.Module`s must be initialized inside an `hk.transform`."
            )

        frame = base.current_frame()
        state = base.ModuleState(module=self, method_name=method_name)
        with frame.module(state), _module_method_call(self, method_name):
            # hk.Module enters the module name scope for all methods.
            module_name = getattr(self, "module_name", None)
            f = functools.partial(unbound_method, self)
            f = functools.partial(run_interceptors, f, method_name, self)
            # TODO(tomhennigan): With omnistaging primitives (like named call) will
            # stage out return values eagerly. For functions that produce non-Array
            # values (e.g. `def is_batched(self, x) -> bool`) a tracer will be
            # returned that might result in a concretization error. For now we only
            # enable named call on __call__ (covering 99% of the interesting usages)
            # with an assumption that __call__ is `f(*) -> Tree[Array]`. Longer term
            # we may want to split static and dynamic results in named call to support
            # other methods.
            if modules_with_named_call and module_name and method_name == "__call__":
                local_name = module_name.split("/")[-1]
                f = stateful.named_call(f, name=local_name)

            out = f(*args, **kwargs)

            # Module names are set in the constructor. If `f` is the constructor
            # then the name only exists *after* `f` has run, so read it again.
            # Otherwise parents would never be notified when this call is
            # `__init__`.
            if module_name is None:
                module_name = getattr(self, "module_name", None)

            # Notify parent modules about our existence. `frame.module(state)`
            # presumably pushes this call's own state onto `frame.module_stack`,
            # so skip `self` to avoid registering a module as its own submodule.
            if module_name is not None:
                for module_state in frame.module_stack:
                    if module_state.module is not self:
                        module_state.module._submodules.add(module_name)  # pylint: disable=protected-access
        return out
Example #2
0
 def jitted_fun(x, non_jaxtype):
   """Calls `fun_with_non_jaxtype_output` through named_call and returns `x`.

   Only the array result leaves this function. The non-JAX-type result is
   checked here and then dropped.
   """
   # named_call may hand back a non-JAX type (supported), but returning it out
   # of the jit would not be supported, so it is consumed locally.
   result, passthrough = stateful.named_call(fun_with_non_jaxtype_output)(
       x, non_jaxtype)
   self.assertEqual(passthrough, non_jaxtype)
   return result
Example #3
0
  def test_named_call_jax_transforms(self, jax_transform):
    """Wrapping `jnp.sum` in named_call must not change `jax_transform`'s output."""
    x = jnp.array([1.])
    expected = jax_transform(jnp.sum)(x)
    actual = jax_transform(stateful.named_call(jnp.sum, name="test"))(x)
    self.assertEqual(expected, actual)
Example #4
0
  def test_named_call_non_jaxtype_arg(self):
    """A static non-JAX-type argument must survive named_call under jit."""
    # A genuine JAX-type argument is passed alongside the non-JAX one. Without
    # the invalid-JaxType filter, its presence would force the non-JAX value to
    # be abstracted, and JAX would then fail when the branch below evaluates
    # that abstract value.
    def pick(not_a_jaxtype, a_jaxtype):
      return a_jaxtype if not_a_jaxtype else 0

    named = stateful.named_call(pick, name="test")
    result = jax.jit(named, static_argnums=(0,))("not a Jaxtype", 1)
    self.assertEqual(result, 1)
Example #5
0
    def wrapped(self, *args, **kwargs):
        """Runs the original method inside a module frame and notifies parents.

        The call is routed through ``run_interceptors``. For ``__call__`` it is
        optionally wrapped in ``stateful.named_call`` when
        ``modules_with_named_call`` is set. Once it returns, the module's name is
        recorded in the ``_submodules`` of every other module on the frame's
        module stack.

        Raises:
          ValueError: If no frame is active (used outside ``hk.transform``).
        """
        if not base.frame_stack:
            raise ValueError(
                "All `hk.Module`s must be initialized inside an `hk.transform`."
            )

        # Name under which submodules created by this call are grouped. It
        # defaults to the real method name, but `@name_like("other_method")` can
        # override it. Interceptors and custom getters still receive the real
        # `method_name`; this value only affects submodule naming.
        naming_method = getattr(unbound_method, _CUSTOM_NAME, method_name)

        current = base.current_frame()
        call_state = base.ModuleState(module=self, method_name=naming_method)
        with current.module(call_state), _module_method_call(self, method_name):
            # Every method of an hk.Module runs in the module's name scope.
            name = getattr(self, "module_name", None)
            call = functools.partial(
                run_interceptors,
                functools.partial(unbound_method, self),
                method_name,
                self,
            )
            # TODO(tomhennigan): Under omnistaging, primitives such as named call
            # stage out their return values eagerly, so a method returning a
            # non-Array (e.g. `def is_batched(self, x) -> bool`) would get a
            # tracer back and may hit a concretization error. Hence named call
            # is only applied to `__call__`, assumed to be `f(*) -> Tree[Array]`.
            # Splitting static and dynamic results would lift this restriction.
            if modules_with_named_call and name and method_name == "__call__":
                call = stateful.named_call(call, name=name.split("/")[-1])

            result = call(*args, **kwargs)

            # The constructor is what assigns `module_name`, so for `__init__`
            # the name is only known once the call has returned. Other methods
            # need it up front so they can be wrapped in `named_call`.
            if name is None:
                name = getattr(self, "module_name", None)

            # Register this module with every *other* module on the stack.
            if name is not None:
                parents = (entry.module for entry in current.module_stack
                           if entry.module is not self)
                for parent in parents:
                    parent._submodules.add(name)  # pylint: disable=protected-access
        return result
Example #6
0
  def wrapped(self, *args, **kwargs):
    """Calls the original method with a group name set before and after.

    The original method is bound to ``self`` and routed through
    ``run_interceptors``. The call runs inside ``frame.module(state)`` and
    ``_module_method_call``. Depending on configuration, the call is also
    wrapped in ``jax.named_call`` (when ``jax_experimental_name_stack`` is
    on) or in ``stateful.named_call`` (legacy path, ``__call__`` only, gated on
    ``profiler_name_scopes``), presumably so module/method names show up in
    traces and profiles. After the call, the module name is added to the
    ``_submodules`` set of every other module on the frame's module stack.

    Args:
      self: The module instance whose method is being invoked.
      *args: Positional arguments forwarded to the original method.
      **kwargs: Keyword arguments forwarded to the original method.

    Returns:
      The value returned by the (intercepted) original method.

    Raises:
      ValueError: If there is no active frame, i.e. the module is used outside
        of an ``hk.transform``.
    """
    if not base.frame_stack:
      raise ValueError(
          "All `hk.Module`s must be initialized inside an `hk.transform`.")

    # Submodules are associated with this method. We allow users to associate
    # submodules with a different method than the one being called via
    # `@name_like("other_method")`. Interceptors and custom getters are still
    # provided the actual method name (e.g. "submodule_method_name" is only used
    # for naming submodules).
    submodule_method_name = getattr(unbound_method, _CUSTOM_NAME, method_name)

    frame = base.current_frame()
    state = base.ModuleState(module=self, method_name=submodule_method_name)
    with frame.module(state), _module_method_call(self, method_name):
      # hk.Module enters the module name scope for all methods.
      module_name = getattr(self, "module_name", None)
      # Bind the method to `self`, then route the call through the registered
      # interceptors, which receive the real `method_name`.
      f = functools.partial(unbound_method, self)
      f = functools.partial(run_interceptors, f, method_name, self)
      if jax.config.jax_experimental_name_stack and module_name:
        # Only the last path component is used as the scope name.
        local_module_name = module_name.split("/")[-1]
        f = jax.named_call(f, name=local_module_name)
        # NOTE(review): the method-name wrapper is applied last, so it wraps
        # the module-name wrapper. If `jax.named_call` scopes nest
        # outer-to-inner, this yields `<method>/<module>` rather than
        # `<module>/<method>`. Confirm this ordering is intended.
        if method_name != "__call__":
          f = jax.named_call(f, name=method_name)
      elif module_name:
        # TODO(lenamartens): remove this branch once jax_experimental_name_stack
        # flag is removed.
        cfg = config.get_config()
        # Legacy path: only `__call__` gets a named scope, and only when
        # profiler name scopes are enabled in the Haiku config.
        if cfg.profiler_name_scopes and method_name == "__call__":
          local_module_name = module_name.split("/")[-1]
          f = stateful.named_call(f, name=local_module_name)

      out = f(*args, **kwargs)

      # Module names are set in the constructor. If `f` is the constructor then
      # its name will only be set **after** `f` has run. For methods other
      # than `__init__` we need the name before running in order to wrap their
      # execution with `named_call`.
      if module_name is None:
        module_name = getattr(self, "module_name", None)

      # Notify parent modules about our existence.
      if module_name is not None:
        for module_state in frame.module_stack:
          # Skip this module's own entry (presumably pushed by
          # `frame.module(state)` above) so it never registers itself.
          if module_state.module is not self:
            module_state.module._submodules.add(module_name)  # pylint: disable=protected-access
    return out
Example #7
0
 def test_named_call_partial_function(self):
   """named_call composes with jit when a leading argument is bound by partial."""
   named = stateful.named_call(lambda x, y: y if x else None)
   jitted = jax.jit(functools.partial(named, True))
   self.assertEqual(jitted(5), 5)
Example #8
0
 def test_static_argnums_named_call(self):
   """named_call composes with jit when the first argument is static."""
   jitted = jax.jit(
       stateful.named_call(lambda x, y: y if x else None, name="test"),
       static_argnums=(0,))
   self.assertEqual(jitted(True, 5), 5)
Example #9
0
 def f(x):
   """Applies a freshly built `SquareModule` to `x` under the name "square"."""
   square = stateful.named_call(SquareModule(), name="square")
   return square(x)