def wrapped(self, *args, **kwargs):
  """Calls the original method with a group name set before and after.

  Runs `unbound_method` bound to `self` while `frame.module(state)` and
  `_module_method_call(self, method_name)` are active. The call is routed
  through `run_interceptors`. When `modules_with_named_call` is set and the
  method is `__call__`, the call is additionally wrapped in
  `named_call.stateful_named_call`, named after the last `/`-separated
  component of the module name.

  After the call, every other module on the frame's module stack records this
  module's name in its `_submodules` set.

  Args:
    self: The module instance the method is bound to.
    *args: Positional arguments forwarded to the original method.
    **kwargs: Keyword arguments forwarded to the original method.

  Returns:
    Whatever the original method returns.

  Raises:
    ValueError: If there is no active frame (i.e. called outside of
      `hk.transform`).
  """
  if not base.frame_stack:
    raise ValueError(
        "All `hk.Module`s must be initialized inside an `hk.transform`."
    )
  frame = base.current_frame()
  state = base.ModuleState(module=self, method_name=method_name)
  with frame.module(state), _module_method_call(self, method_name):
    # hk.Module enters the module name scope for all methods.
    module_name = getattr(self, "module_name", None)
    f = functools.partial(unbound_method, self)
    f = functools.partial(run_interceptors, f, method_name, self)
    # TODO(tomhennigan): With omnistaging primitives (like named call) will
    # stage out return values eagerly. For functions that produce non-Array
    # values (e.g. `def is_batched(self, x) -> bool`) a tracer will be
    # returned that might result in a concretization error. For now we only
    # enable named call on __call__ (covering 99% of the interesting usages)
    # with an assumption that __call__ is `f(*) -> Tree[Array]`. Longer term
    # we may want to split static and dynamic results in named call to support
    # other methods.
    if modules_with_named_call and module_name and method_name == "__call__":
      local_name = module_name.split("/")[-1]
      f = named_call.stateful_named_call(f, name=local_name)
    out = f(*args, **kwargs)

    # The module name is assigned in the constructor. When `f` is `__init__`
    # the name only exists after `f` has run. Re-read it so that freshly
    # constructed modules are still registered with their parents.
    if module_name is None:
      module_name = getattr(self, "module_name", None)

    # Notify parent modules about our existence. `frame.module(state)`
    # presumably places this module on `frame.module_stack` as well; skip it
    # so a module is never recorded as its own submodule.
    if module_name is not None:
      for module_state in frame.module_stack:
        if module_state.module is not self:
          module_state.module._submodules.add(module_name)  # pylint: disable=protected-access
  return out
def wrapped(module, *args, **kwargs):
  """Calls the original method with a group name set before and after.

  Runs `unbound_method` bound to `module` while `frame.module(state)` and
  `_module_method_call(module, method_name)` are active. The call is routed
  through `run_interceptors`. When `modules_with_named_call` is set and the
  method is `__call__`, the call is additionally wrapped in
  `named_call.stateful_named_call`, named after the last `/`-separated
  component of the module name.

  After the call, every other module on the frame's module stack records this
  module's name in its `_submodules` set.

  Args:
    module: The module instance the method is bound to.
    *args: Positional arguments forwarded to the original method.
    **kwargs: Keyword arguments forwarded to the original method.

  Returns:
    Whatever the original method returns.

  Raises:
    ValueError: If there is no active frame (i.e. called outside of
      `hk.transform`).
  """
  if not base.frame_stack:
    raise ValueError(
        "All `hk.Module`s must be initialized inside an `hk.transform`."
    )
  frame = base.current_frame()
  state = base.ModuleState(module=module, method_name=method_name)
  with frame.module(state), _module_method_call(module, method_name):
    # hk.Module enters the module name scope for all methods.
    module_name = getattr(module, "module_name", None)
    f = functools.partial(unbound_method, module)
    f = functools.partial(run_interceptors, f, method_name, module)
    # Only `__call__` is wrapped in a named call. With omnistaging, named call
    # stages out return values eagerly, so a method returning a non-Array
    # value (e.g. `def is_batched(self, x) -> bool`) would get a tracer back
    # and could hit a concretization error. `__call__` is assumed to be
    # `f(*) -> Tree[Array]`.
    if modules_with_named_call and module_name and method_name == "__call__":
      local_name = module_name.split("/")[-1]
      f = named_call.stateful_named_call(f, name=local_name)
    out = f(*args, **kwargs)

    # The module name is assigned in the constructor. When `f` is `__init__`
    # the name only exists after `f` has run. Re-read it so that freshly
    # constructed modules are still registered with their parents.
    if module_name is None:
      module_name = getattr(module, "module_name", None)

    # Notify parent modules about our existence. `frame.module(state)`
    # presumably places this module on `frame.module_stack` as well; skip it
    # so a module is never recorded as its own submodule.
    if module_name is not None:
      for module_state in frame.module_stack:
        if module_state.module is not module:
          module_state.module._submodules.add(module_name)  # pylint: disable=protected-access
  return out
def wrapped(self, *args, **kwargs):
  """Runs the original method inside this module's frame entry and scope.

  The frame entry is a `base.ModuleState` whose method name is the custom name
  attached via `_CUSTOM_NAME` (if any), otherwise `method_name`. That name is
  only used for naming submodules; interceptors still receive `method_name`.
  For `__call__`, and only when `modules_with_named_call` is set and the
  module already has a name, the call is wrapped in `stateful.named_call`.
  Afterwards every other module on the frame's module stack records this
  module's name in its `_submodules`.

  Args:
    self: The module instance the method is bound to.
    *args: Positional arguments forwarded to the original method.
    **kwargs: Keyword arguments forwarded to the original method.

  Returns:
    The original method's return value.

  Raises:
    ValueError: If there is no active frame (outside of `hk.transform`).
  """
  if not base.frame_stack:
    raise ValueError(
        "All `hk.Module`s must be initialized inside an `hk.transform`."
    )

  # `@name_like("other_method")` attaches an alternative name to the unbound
  # method. It only changes how submodules created here are named;
  # interceptors and custom getters still see the real `method_name`.
  naming_method = getattr(unbound_method, _CUSTOM_NAME, method_name)
  current = base.current_frame()
  entry = base.ModuleState(module=self, method_name=naming_method)

  with current.module(entry), _module_method_call(self, method_name):
    # hk.Module enters the module name scope for all methods.
    name_before_call = getattr(self, "module_name", None)
    call = functools.partial(
        run_interceptors,
        functools.partial(unbound_method, self),
        method_name,
        self,
    )

    # TODO(tomhennigan): With omnistaging primitives (like named call) will
    # stage out return values eagerly. For functions that produce non-Array
    # values (e.g. `def is_batched(self, x) -> bool`) a tracer will be
    # returned that might result in a concretization error. For now we only
    # enable named call on __call__ (covering 99% of the interesting usages)
    # with an assumption that __call__ is `f(*) -> Tree[Array]`. Longer term
    # we may want to split static and dynamic results in named call to support
    # other methods.
    wants_named_call = (
        modules_with_named_call and name_before_call
        and method_name == "__call__")
    if wants_named_call:
      call = stateful.named_call(
          call, name=name_before_call.rsplit("/", 1)[-1])

    out = call(*args, **kwargs)

    # `module_name` is assigned by the constructor, so when the wrapped method
    # is `__init__` it only exists once `call` has returned. Every other
    # method needs it up front so `named_call` can be applied.
    registered_name = (
        name_before_call if name_before_call is not None
        else getattr(self, "module_name", None))

    # Tell every enclosing module (but never this one) that we exist.
    if registered_name is not None:
      enclosing = (
          s.module for s in current.module_stack if s.module is not self)
      for parent in enclosing:
        parent._submodules.add(registered_name)  # pylint: disable=protected-access
  return out
def wrapped(self, *args, **kwargs):
  """Runs the original method inside this module's frame entry and scope.

  The frame entry is a `base.ModuleState` whose method name is the custom name
  attached via `_CUSTOM_NAME` (if any), otherwise `method_name`. That name is
  only used for naming submodules; interceptors still receive `method_name`.

  If the module already has a name:
    * with `jax.config.jax_experimental_name_stack` enabled, the call is
      wrapped in `jax.named_call` using the last component of the module name,
      and, for methods other than `__call__`, additionally in a
      `jax.named_call` using the method name (applied last, so outermost);
    * otherwise, only for `__call__` and only when the Haiku config's
      `profiler_name_scopes` is set, it is wrapped in `stateful.named_call`.

  Afterwards every other module on the frame's module stack records this
  module's name in its `_submodules`.

  Args:
    self: The module instance the method is bound to.
    *args: Positional arguments forwarded to the original method.
    **kwargs: Keyword arguments forwarded to the original method.

  Returns:
    The original method's return value.

  Raises:
    ValueError: If there is no active frame (outside of `hk.transform`).
  """
  if not base.frame_stack:
    raise ValueError(
        "All `hk.Module`s must be initialized inside an `hk.transform`.")

  # `@name_like("other_method")` attaches an alternative name to the unbound
  # method. It only changes how submodules created here are named;
  # interceptors and custom getters still see the real `method_name`.
  naming_method = getattr(unbound_method, _CUSTOM_NAME, method_name)
  current = base.current_frame()
  entry = base.ModuleState(module=self, method_name=naming_method)

  with current.module(entry), _module_method_call(self, method_name):
    # hk.Module enters the module name scope for all methods.
    name_before_call = getattr(self, "module_name", None)
    call = functools.partial(
        run_interceptors,
        functools.partial(unbound_method, self),
        method_name,
        self,
    )

    use_name_stack = jax.config.jax_experimental_name_stack
    if name_before_call:
      short_name = name_before_call.rsplit("/", 1)[-1]
      if use_name_stack:
        call = jax.named_call(call, name=short_name)
        if method_name != "__call__":
          call = jax.named_call(call, name=method_name)
      # TODO(lenamartens): remove this branch once jax_experimental_name_stack
      # flag is removed.
      elif (config.get_config().profiler_name_scopes
            and method_name == "__call__"):
        call = stateful.named_call(call, name=short_name)

    out = call(*args, **kwargs)

    # `module_name` is assigned by the constructor, so when the wrapped method
    # is `__init__` it only exists once `call` has returned. Every other
    # method needs it up front so `named_call` can be applied.
    registered_name = (
        name_before_call if name_before_call is not None
        else getattr(self, "module_name", None))

    # Tell every enclosing module (but never this one) that we exist.
    if registered_name is not None:
      enclosing = (
          s.module for s in current.module_stack if s.module is not self)
      for parent in enclosing:
        parent._submodules.add(registered_name)  # pylint: disable=protected-access
  return out
def wrapped(module, *args, **kwargs):
  """Runs the original method inside this module's frame entry and scope.

  Pushes a `base.ModuleState` for (`module`, `method_name`) via
  `frame.module(...)`, enters `_module_method_call`, and calls
  `unbound_method` directly. Afterwards, if the module has a `module_name`,
  every module on the frame's module stack records it in its `_submodules`.

  Args:
    module: The module instance the method is bound to.
    *args: Positional arguments forwarded to the original method.
    **kwargs: Keyword arguments forwarded to the original method.

  Returns:
    The original method's return value.

  Raises:
    ValueError: If there is no active frame (outside of `hk.transform`).
  """
  if not base.frame_stack:
    raise ValueError(
        "All `hk.Module`s must be initialized inside an `hk.transform`."
    )
  current = base.current_frame()
  entry = base.ModuleState(module=module, method_name=method_name)
  with current.module(entry), _module_method_call(module, method_name):
    # hk.Module enters the module name scope for all methods.
    result = unbound_method(module, *args, **kwargs)

    # Read the name only after the call, so that `__init__` (which sets
    # `module_name`) is covered too. Then notify parent modules.
    registered_name = getattr(module, "module_name", None)
    if registered_name is not None:
      for parent in (s.module for s in current.module_stack):
        parent._submodules.add(registered_name)  # pylint: disable=protected-access
  return result
def simulate_module_call(module):
  """Returns the current frame's module context for a simulated `__call__`.

  Builds a `base.ModuleState` for `module` with method name `"__call__"` and
  returns `frame.module(state)` for the current frame. No method of `module`
  is invoked; the caller is expected to enter the returned context.

  Args:
    module: The module to present as being called.

  Returns:
    The object returned by `base.current_frame().module(...)`.
  """
  return base.current_frame().module(
      base.ModuleState(module=module, method_name="__call__"))