def wrapped(self, *args, **kwargs):
  """Calls the original method with a group name set before and after.

  Pushes a ``ModuleState`` for ``self`` onto the current frame and runs the
  method through ``run_interceptors``. When named calls are enabled, a
  ``__call__`` invocation is also wrapped in ``stateful.named_call``. Finally,
  this module's name is recorded in the ``_submodules`` set of every *other*
  module on the frame's module stack.

  Raises:
    ValueError: if there is no active frame, i.e. the module is used outside
      of an ``hk.transform``.
  """
  if not base.frame_stack:
    raise ValueError(
        "All `hk.Module`s must be initialized inside an `hk.transform`."
    )

  frame = base.current_frame()
  state = base.ModuleState(module=self, method_name=method_name)
  with frame.module(state), _module_method_call(self, method_name):
    # hk.Module enters the module name scope for all methods.
    module_name = getattr(self, "module_name", None)
    f = functools.partial(unbound_method, self)
    f = functools.partial(run_interceptors, f, method_name, self)
    # TODO(tomhennigan): With omnistaging primitives (like named call) will
    # stage out return values eagerly. For functions that produce non-Array
    # values (e.g. `def is_batched(self, x) -> bool`) a tracer will be
    # returned that might result in a concretization error. For now we only
    # enable named call on __call__ (covering 99% of the interesting usages)
    # with an assumption that __call__ is `f(*) -> Tree[Array]`. Longer term
    # we may want to split static and dynamic results in named call to support
    # other methods.
    if modules_with_named_call and module_name and method_name == "__call__":
      local_name = module_name.split("/")[-1]
      f = stateful.named_call(f, name=local_name)

    out = f(*args, **kwargs)

    # Module names are set in the constructor. If `f` is the constructor, its
    # name only exists **after** `f` has run, so re-read it here. Otherwise
    # parent modules would never learn about modules created by a constructor.
    if module_name is None:
      module_name = getattr(self, "module_name", None)

    # Notify parent modules about our existence.
    if module_name is not None:
      for module_state in frame.module_stack:
        # `frame.module(state)` above presumably pushed our own state onto
        # the stack. Skip it so a module is never recorded as its own
        # submodule.
        if module_state.module is not self:
          module_state.module._submodules.add(module_name)  # pylint: disable=protected-access
  return out
def jitted_fun(x, non_jaxtype):
  """Runs `fun_with_non_jaxtype_output` under `named_call`; returns its first output.

  The second output is a non-JAX type. It may legitimately flow out of
  `named_call`, so it is checked here. It is intentionally not returned,
  because returning it out of the jit should not be supported.
  """
  wrapped_fun = stateful.named_call(fun_with_non_jaxtype_output)
  result, passthrough = wrapped_fun(x, non_jaxtype)
  self.assertEqual(passthrough, non_jaxtype)
  return result
def test_named_call_jax_transforms(self, jax_transform):
  """Wrapping in `named_call` must not change a transformed function's output."""
  fn = jnp.sum
  inputs = jnp.array([1.])
  expected = jax_transform(fn)(inputs)
  named_fn = stateful.named_call(fn, name="test")
  actual = jax_transform(named_fn)(inputs)
  self.assertEqual(expected, actual)
def test_named_call_non_jaxtype_arg(self):
  """A static non-JAX-type argument must pass through `named_call` under jit.

  For this test to fail without the invalid-JAX-type filter, a valid JAX type
  has to be passed alongside the invalid one. That forces the invalid value
  to be raised to an abstract value.
  """
  def f(not_a_jaxtype, a_jaxtype):
    # Branching on the argument makes JAX evaluate it. That would fail if the
    # non-JAX-type had been turned into an abstract value.
    return a_jaxtype if not_a_jaxtype else 0

  named = stateful.named_call(f, name="test")
  result = jax.jit(named, static_argnums=(0,))("not a Jaxtype", 1)
  self.assertEqual(result, 1)
def wrapped(self, *args, **kwargs):
  """Calls the original method with a group name set before and after.

  Pushes a ``ModuleState`` for ``self`` onto the current frame and runs the
  method through ``run_interceptors``. When named calls are enabled, a
  ``__call__`` invocation is also wrapped in ``stateful.named_call``.
  Afterwards, this module's name is added to the ``_submodules`` of every
  other module on the frame's module stack.

  Raises:
    ValueError: if there is no active frame (outside of ``hk.transform``).
  """
  if not base.frame_stack:
    raise ValueError(
        "All `hk.Module`s must be initialized inside an `hk.transform`."
    )

  # `@name_like("other_method")` lets submodules created while running this
  # method be named as if they were created by a different method. Only
  # submodule naming uses this; interceptors and custom getters still receive
  # the real `method_name`.
  submodule_method_name = getattr(unbound_method, _CUSTOM_NAME, method_name)
  frame = base.current_frame()
  state = base.ModuleState(module=self, method_name=submodule_method_name)

  with frame.module(state), _module_method_call(self, method_name):
    # hk.Module enters the module name scope for all methods. The name is
    # needed *before* running so `__call__` can be wrapped in `named_call`.
    name = getattr(self, "module_name", None)
    call = functools.partial(
        run_interceptors,
        functools.partial(unbound_method, self),
        method_name,
        self,
    )

    # TODO(tomhennigan): With omnistaging primitives (like named call) will
    # stage out return values eagerly. For functions that produce non-Array
    # values (e.g. `def is_batched(self, x) -> bool`) a tracer will be
    # returned that might result in a concretization error. For now we only
    # enable named call on __call__ (covering 99% of the interesting usages)
    # with an assumption that __call__ is `f(*) -> Tree[Array]`. Longer term
    # we may want to split static and dynamic results in named call to support
    # other methods.
    wants_named_call = (
        modules_with_named_call and name and method_name == "__call__")
    if wants_named_call:
      call = stateful.named_call(call, name=name.split("/")[-1])

    out = call(*args, **kwargs)

    # The constructor is what assigns `module_name`. If this call *was* the
    # constructor, the name only becomes available now.
    if name is None:
      name = getattr(self, "module_name", None)

    # Tell every enclosing module (but not ourselves) that we exist.
    if name is not None:
      parents = (s.module for s in frame.module_stack if s.module is not self)
      for parent in parents:
        parent._submodules.add(name)  # pylint: disable=protected-access
  return out
def wrapped(self, *args, **kwargs):
  """Calls the original method with a group name set before and after.

  Pushes a ``ModuleState`` for ``self`` onto the current frame and runs the
  method through ``run_interceptors``. Depending on configuration, the call
  is wrapped in named scopes. Afterwards, this module's name is recorded in
  the ``_submodules`` of every other module on the frame's module stack.

  Raises:
    ValueError: if there is no active frame, i.e. the module is used outside
      of an ``hk.transform``.
  """
  if not base.frame_stack:
    raise ValueError(
        "All `hk.Module`s must be initialized inside an `hk.transform`.")

  # Submodules are associated with this method. We allow users to associate
  # submodules with a different method than the one being called via
  # `@name_like("other_method")`. Interceptors and custom getters are still
  # provided the actual method name (e.g. "submodule_method_name" is only used
  # for naming submodules).
  submodule_method_name = getattr(unbound_method, _CUSTOM_NAME, method_name)
  frame = base.current_frame()
  state = base.ModuleState(module=self, method_name=submodule_method_name)
  with frame.module(state), _module_method_call(self, method_name):
    # hk.Module enters the module name scope for all methods.
    module_name = getattr(self, "module_name", None)
    f = functools.partial(unbound_method, self)
    f = functools.partial(run_interceptors, f, method_name, self)
    if jax.config.jax_experimental_name_stack and module_name:
      # Name-stack path: every method gets a scope named after the last
      # component of the module name.
      local_module_name = module_name.split("/")[-1]
      f = jax.named_call(f, name=local_module_name)
      if method_name != "__call__":
        # Non-`__call__` methods additionally get a scope named after the
        # method. NOTE(review): this wraps the module-scoped `f`, so the
        # method-name scope is entered first (outermost) when `f` runs.
        # Confirm this nesting order is intended.
        f = jax.named_call(f, name=method_name)
    elif module_name:
      # Legacy path: only `__call__` is wrapped, and only when profiler name
      # scopes are enabled in the Haiku config.
      # TODO(lenamartens): remove this branch once jax_experimental_name_stack
      # flag is removed.
      cfg = config.get_config()
      if cfg.profiler_name_scopes and method_name == "__call__":
        local_module_name = module_name.split("/")[-1]
        f = stateful.named_call(f, name=local_module_name)

    out = f(*args, **kwargs)

    # Module names are set in the constructor. If `f` is the constructor then
    # its name will only be set **after** `f` has run. For methods other
    # than `__init__` we need the name before running in order to wrap their
    # execution with `named_call`.
    if module_name is None:
      module_name = getattr(self, "module_name", None)

    # Notify parent modules about our existence.
    if module_name is not None:
      for module_state in frame.module_stack:
        # Skip our own entry so a module is not recorded as its own submodule.
        if module_state.module is not self:
          module_state.module._submodules.add(module_name)  # pylint: disable=protected-access
  return out
def test_named_call_partial_function(self):
  """`named_call` composes with `functools.partial` under `jax.jit`."""
  named = stateful.named_call(lambda x, y: y if x else None)
  bound = functools.partial(named, True)
  jitted = jax.jit(bound)
  self.assertEqual(jitted(5), 5)
def test_static_argnums_named_call(self):
  """`named_call` works under `jax.jit` when the first argument is static."""
  jitted = jax.jit(
      stateful.named_call(lambda x, y: y if x else None, name="test"),
      static_argnums=(0,),
  )
  self.assertEqual(jitted(True, 5), 5)
def f(x):
  """Applies a fresh `SquareModule` to `x` inside a `named_call` named "square"."""
  module = SquareModule()
  named_module = stateful.named_call(module, name="square")
  return named_module(x)