def wrapper(*args, **kwargs):
  """Reserves the RNG keys `fun` requests up front, then calls `fun`.

  `fun` is first evaluated abstractly with `jax.eval_shape` inside a new
  context built from the current frame's params, state and RNG, while the
  number of requested RNG keys is counted. If that count is non-zero, the same
  number of keys is reserved on the current RNG sequence before `fun` is
  called for real with the original arguments.
  """
  base.assert_context("optimize_rng_use")

  # Extract all current state. Empty params/state are passed on as `None`.
  frame = base.current_frame()
  current_params = (
      data_structures.to_haiku_dict(frame.params) if frame.params else None)
  current_state = (
      base.extract_state(frame.state, initial=True) if frame.state else None)
  rng_seq = frame.rng_stack.peek()
  rng_state = None if rng_seq is None else rng_seq.internal_state

  def pure_fun(params, state, rng, *args, **kwargs):
    # Runs `fun` in a context of its own, seeded with the given values.
    with base.new_context(params=params, state=state, rng=rng):
      return fun(*args, **kwargs)

  # Abstract evaluation only: used to count how many keys are requested.
  with count_hk_rngs_requested() as rng_count_f:
    jax.eval_shape(
        pure_fun, current_params, current_state, rng_state, *args, **kwargs)
  num_requested = rng_count_f()

  if num_requested:
    base.current_frame().rng_stack.peek().reserve(num_requested)
  return fun(*args, **kwargs)
def wrapped(self, *args, **kwargs):
  """Runs the wrapped method inside this module's scope, via interceptors.

  Raises:
    ValueError: If there is no active frame (`base.frame_stack` is empty).
  """
  if not base.frame_stack:
    raise ValueError(
        "All `hk.Module`s must be initialized inside an `hk.transform`."
    )

  frame = base.current_frame()
  call_state = base.ModuleState(module=self, method_name=method_name)
  with frame.module(call_state), _module_method_call(self, method_name):
    # hk.Module enters the module name scope for all methods.
    module_name = getattr(self, "module_name", None)

    call = functools.partial(
        run_interceptors,
        functools.partial(unbound_method, self),
        method_name,
        self,
    )

    # TODO(tomhennigan): With omnistaging primitives (like named call) will
    # stage out return values eagerly. For functions that produce non-Array
    # values (e.g. `def is_batched(self, x) -> bool`) a tracer will be
    # returned that might result in a concretization error. For now we only
    # enable named call on __call__ (covering 99% of the interesting usages)
    # with an assumption that __call__ is `f(*) -> Tree[Array]`. Longer term
    # we may want to split static and dynamic results in named call to support
    # other methods.
    if modules_with_named_call and module_name and method_name == "__call__":
      scope_name = module_name.rsplit("/", 1)[-1]
      call = named_call.stateful_named_call(call, name=scope_name)

    result = call(*args, **kwargs)

    # Notify parent modules about our existence.
    if module_name is not None:
      for entry in frame.module_stack:
        entry.module._submodules.add(module_name)  # pylint: disable=protected-access

  return result
def __call__(self, *args, **kwargs):
  """Returns the inner `(params, state)` pair, lifting it on init.

  While `hk.running_init()` is true, the inner init function is run and its
  params and state are packed into the outer frame under `self._prefix_name`.
  Otherwise the inner params and state are unpacked from the outer frame.
  """
  frame = base.current_frame()
  outer_params = frame.params
  outer_state = frame.state

  if not hk.running_init():
    # An empty prefix name means the entries are looked up without a prefix.
    prefix = f"{self._prefix_name}/" if self._prefix_name else ""
    unpacked_params = unpack_from_dict(outer_params, prefix)
    unpacked_state = unpack_from_dict(outer_state, prefix)
    unpacked_state = base.extract_state(unpacked_state, initial=False)
    unpacked_params = hk.data_structures.to_haiku_dict(unpacked_params)
    unpacked_state = hk.data_structures.to_haiku_dict(unpacked_state)
    return unpacked_params, unpacked_state

  inner_params, inner_state = self._init_fn(*args, **kwargs)

  # Lift parameters into this transform's params_dict.
  check_reuse = not self._allow_reuse
  pack_into_dict(
      inner_params,
      outer_params,
      self._prefix_name,
      check_param_reuse=check_reuse,
  )
  pack_into_dict(
      inner_state,
      outer_state,
      self._prefix_name,
      state=True,
      check_param_reuse=check_reuse,
  )
  return inner_params, inner_state
def temporary_internal_state(state: InternalState):
  """Returns `base.frame_stack(...)` for a frame evolved from `state`.

  The current frame is evolved with the params and state held by `state`; a
  non-`None` `state.rng` is wrapped in a new `base.PRNGSequence` first.
  """
  rng_seq = None if state.rng is None else base.PRNGSequence(state.rng)
  evolved = base.current_frame().evolve(
      params=state.params, state=state.state, rng=rng_seq)
  return base.frame_stack(evolved)
def wrapped(module, *args, **kwargs):
  """Runs the wrapped method inside the module's scope, via interceptors.

  Raises:
    ValueError: If there is no active frame (`base.frame_stack` is empty).
  """
  if not base.frame_stack:
    raise ValueError(
        "All `hk.Module`s must be initialized inside an `hk.transform`."
    )

  frame = base.current_frame()
  call_state = base.ModuleState(module=module, method_name=method_name)
  with frame.module(call_state), _module_method_call(module, method_name):
    # hk.Module enters the module name scope for all methods.
    module_name = getattr(module, "module_name", None)

    call = functools.partial(
        run_interceptors,
        functools.partial(unbound_method, module),
        method_name,
        module,
    )

    # NOTE(review): named call wraps every method that has a module name here,
    # not only `__call__` - confirm this is intended for methods returning
    # non-array values.
    if modules_with_named_call and module_name:
      scope_name = module_name.rsplit("/", 1)[-1]
      call = named_call.stateful_named_call(call, name=scope_name)

    result = call(*args, **kwargs)

    # Notify parent modules about our existence.
    if module_name is not None:
      for entry in frame.module_stack:
        entry.module._submodules.add(module_name)  # pylint: disable=protected-access

  return result
def __init__(self, name: Optional[str] = None):
  """Initializes the current module with the given name.

  Subclasses should call this constructor before creating other modules or
  variables such that those modules are named correctly.

  Args:
    name: An optional string name for the class. Must be a valid Python
      identifier. If ``name`` is not provided then an existing non-``None``
      ``self.name`` attribute is used; failing that, the class name of the
      current instance is converted to ``lower_snake_case`` and used instead.

  Raises:
    ValueError: If the resolved name is not a valid identifier.
  """
  if name is None:
    # Attribute possibly assigned by @dataclass constructor.
    name = getattr(self, "name", None)
  if name is None:
    name = utils.camel_to_snake(type(self).__name__)

  if not valid_identifier(name):
    raise ValueError(
        "'{}' is not a valid module name (must be a valid Python identifier)"
        .format(name))

  self._submodules: Set[str] = set()
  self.module_name = unique_and_canonical_name(name)
  # The short name is the last "/"-separated component of the module name.
  self.name = self.module_name.rsplit("/", 1)[-1]
  self._creation_frame_id = base.current_frame().frame_id
def internal_state() -> InternalState:
  """Returns an `InternalState` built from the current frame.

  Params and state are passed through `copy_structure`; the rng field is the
  result of `peek()` on the current RNG sequence, or `None` if there is none.
  """
  frame = base.current_frame()
  rng_seq = frame.rng_stack.peek()
  peeked = None if rng_seq is None else rng_seq.peek()
  return InternalState(
      params=copy_structure(frame.params),
      state=copy_structure(frame.state),
      rng=peeked,
  )
def update_internal_state(state: InternalState):
  """Writes the contents of `state` into the current frame.

  Params are only updated when the frame's params are not frozen and
  `state.params` is not `None`. State is always updated. The RNG sequence's
  internal state is replaced only when `state.rng` is not `None`.
  """
  frame = base.current_frame()

  if not (frame.params_frozen or state.params is None):
    update_recursive(frame.params, state.params)
  update_recursive(frame.state, state.state)

  if state.rng is not None:
    frame.rng_stack.peek().replace_internal_state(state.rng)
def reserve_up_to_full_rng_block():
  """If RNG is active in the current frame, reserve up to the default block."""
  # TODO(lenamartens): Fix needing full block reservation in stateful
  # control-flow by keeping track of current key with index and keeping a full
  # block in PRNGSequence at all time.
  active_rng = base.current_frame().rng_stack.peek()
  if not active_rng:
    return
  active_rng.reserve_up_to_full()
def get_frame_data() -> FrameData:
  """Returns a `FrameData` holding copies of the current frame's contents.

  Params, state, constants and the RNG sequence's internal state (or `None`
  when there is no RNG sequence) are each passed through `copy_structure`.
  """
  frame = current_frame()
  rng_seq = frame.rng_stack.peek()
  rng_state = None if rng_seq is None else rng_seq.internal_state
  return FrameData(
      params=copy_structure(frame.params),
      state=copy_structure(frame.state),
      constants=copy_structure(frame.constants),
      rng=copy_structure(rng_state),
  )
def internal_state(*, params=True) -> InternalState:
  """Returns an `InternalState` holding copies of the current frame's contents.

  Args:
    params: If falsy, the returned `params` field is `None` instead of a copy
      of the frame's params.

  Returns:
    An `InternalState` whose state and rng fields (and params, when requested)
    have been passed through `copy_structure`. The rng field is built from the
    RNG sequence's internal state, or from `None` if there is no sequence.
  """
  frame = base.current_frame()
  rng_seq = frame.rng_stack.peek()
  rng_state = None if rng_seq is None else rng_seq.internal_state

  params_copy = copy_structure(frame.params) if params else None
  return InternalState(
      params=params_copy,
      state=copy_structure(frame.state),
      rng=copy_structure(rng_state),
  )
def __call__(self, *args, **kwargs):
  """Returns the inner params, lifting them into the outer frame on init.

  While `hk.running_init()` is true, the inner init function is run and its
  result is packed into the outer params under `self.module_name`. Otherwise
  the params are unpacked from the outer params using that name as prefix.
  """
  outer_params = base.current_frame().params

  if not hk.running_init():
    return unpack_from_dict(outer_params, f"{self.module_name}/")

  inner_params = self._init_fn(*args, **kwargs)
  # Lift parameters into this transform's params_dict.
  pack_into_dict(inner_params, outer_params, self.module_name)
  return inner_params
def update(self, state: hk.State):
  """Updates Haiku's internal state to the given state."""
  frame = base.current_frame()
  for mod_name, bundle in state.items():
    # Entries are stored under `<self._name>/<mod_name>` when a name is set.
    target_name = mod_name if self._name is None else f"{self._name}/{mod_name}"
    for name, value in bundle.items():
      # Keep the already-recorded initial value if there is one; otherwise the
      # new value serves as both initial and current.
      fallback = base.StatePair(value, value)
      entries = frame.state[target_name]
      initial = entries.get(name, fallback).initial
      entries[name] = base.StatePair(initial, value)
def update_modified_frame_data(frame_data: FrameData):
  """Writes the contents of `frame_data` into the current frame.

  Params and constants are only updated while `params_frozen()` is false;
  state is always updated. The RNG sequence's internal state is replaced only
  when `frame_data.rng` is not `None`.
  """
  frame = current_frame()

  # The update order (params, state, constants) is kept as-is.
  if not params_frozen():
    update_recursive_skip_none(frame.params, frame_data.params)
  update_recursive_skip_none(frame.state, frame_data.state)
  if not params_frozen():
    update_recursive_skip_none(frame.constants, frame_data.constants)

  new_rng = frame_data.rng
  if new_rng is None:
    return
  frame.rng_stack.peek().replace_internal_state(new_rng)
def update_modified_frame_data_from_args(params, bundled_state):
  """Writes `params` and the unpacked `bundled_state` into the current frame.

  Args:
    params: Values used to update the frame's params, applied only while
      `params_frozen()` is false.
    bundled_state: A `(state, constants, rng)` triple. State is always
      applied, constants only while `params_frozen()` is false, and a
      non-`None` rng replaces the RNG sequence's internal state.
  """
  state, constants, rng = bundled_state
  frame = current_frame()

  # The update order (params, state, constants) is kept as-is.
  if not params_frozen():
    update_recursive_skip_none(frame.params, params)
  update_recursive_skip_none(frame.state, state)
  if not params_frozen():
    update_recursive_skip_none(frame.constants, constants)

  if rng is None:
    return
  frame.rng_stack.peek().replace_internal_state(rng)
def __call__(self, *args, **kwargs):
  """Returns the inner params, lifting them into the outer frame on init.

  When `base.in_apply()` is true, the params are unpacked from the frame using
  `self.module_name` as prefix. Otherwise the inner init function is run and
  its result is packed into the frame's params under that name.
  """
  frame = base.current_frame()
  bundle_name = self.module_name

  if not base.in_apply():
    # Inside init.
    # Lift parameters into this transform's params_dict.
    init_params = self._init_fn(*args, **kwargs)
    pack_into_dict(init_params, frame.params, bundle_name)
    return init_params

  return unpack_from_dict(frame.params, bundle_name + "/")
def wrapped(self, *args, **kwargs):
  """Runs the wrapped method inside this module's scope, via interceptors.

  Raises:
    ValueError: If there is no active frame (`base.frame_stack` is empty).
  """
  if not base.frame_stack:
    raise ValueError(
        "All `hk.Module`s must be initialized inside an `hk.transform`."
    )

  # Submodules are associated with this method. We allow users to associate
  # submodules with a different method than the one being called via
  # `@name_like("other_method")`. Interceptors and custom getters are still
  # provided the actual method name (e.g. "submodule_method_name" is only used
  # for naming submodules).
  submodule_method_name = getattr(unbound_method, _CUSTOM_NAME, method_name)

  frame = base.current_frame()
  call_state = base.ModuleState(module=self, method_name=submodule_method_name)
  with frame.module(call_state), _module_method_call(self, method_name):
    # hk.Module enters the module name scope for all methods.
    module_name = getattr(self, "module_name", None)

    call = functools.partial(
        run_interceptors,
        functools.partial(unbound_method, self),
        method_name,
        self,
    )

    # TODO(tomhennigan): With omnistaging primitives (like named call) will
    # stage out return values eagerly. For functions that produce non-Array
    # values (e.g. `def is_batched(self, x) -> bool`) a tracer will be
    # returned that might result in a concretization error. For now we only
    # enable named call on __call__ (covering 99% of the interesting usages)
    # with an assumption that __call__ is `f(*) -> Tree[Array]`. Longer term
    # we may want to split static and dynamic results in named call to support
    # other methods.
    if modules_with_named_call and module_name and method_name == "__call__":
      scope_name = module_name.rsplit("/", 1)[-1]
      call = stateful.named_call(call, name=scope_name)

    result = call(*args, **kwargs)

    # Module names are set in the constructor. If `call` is the constructor
    # then its name will only be set **after** `call` has run. For methods
    # other than `__init__` we need the name before running in order to wrap
    # their execution with `named_call`.
    if module_name is None:
      module_name = getattr(self, "module_name", None)

    # Notify parent modules about our existence.
    if module_name is not None:
      for entry in frame.module_stack:
        if entry.module is not self:
          entry.module._submodules.add(module_name)  # pylint: disable=protected-access

  return result
def temporary_frame_data(frame_data: FrameData):
  """Pushes a temporary copy of the frame_data.

  Returns `frame_stack(...)` for the current frame evolved with the copied
  params, state, constants and (when present) a new `PRNGSequence`.
  """
  snapshot = copy_structure(frame_data)

  rng_seq = snapshot.rng
  if rng_seq is not None:
    rng_seq = PRNGSequence(rng_seq)

  assert snapshot.params is not None, "Must initialize module before this call"
  assert snapshot.state is not None, "Must initialize module before this call"
  assert snapshot.constants is not None, "Must initialize module before this call"

  evolved = current_frame().evolve(
      params=snapshot.params,
      state=snapshot.state,
      constants=snapshot.constants,
      rng=rng_seq,
  )
  return frame_stack(evolved)
def wrapped(self, *args, **kwargs):
  """Runs the wrapped method inside this module's scope, via interceptors.

  Raises:
    ValueError: If there is no active frame (`base.frame_stack` is empty).
  """
  if not base.frame_stack:
    raise ValueError(
        "All `hk.Module`s must be initialized inside an `hk.transform`.")

  # Submodules are associated with this method. We allow users to associate
  # submodules with a different method than the one being called via
  # `@name_like("other_method")`. Interceptors and custom getters are still
  # provided the actual method name (e.g. "submodule_method_name" is only used
  # for naming submodules).
  submodule_method_name = getattr(unbound_method, _CUSTOM_NAME, method_name)

  frame = base.current_frame()
  call_state = base.ModuleState(module=self, method_name=submodule_method_name)
  with frame.module(call_state), _module_method_call(self, method_name):
    # hk.Module enters the module name scope for all methods.
    module_name = getattr(self, "module_name", None)

    call = functools.partial(
        run_interceptors,
        functools.partial(unbound_method, self),
        method_name,
        self,
    )

    if jax.config.jax_experimental_name_stack and module_name:
      # Name the call after the module; methods other than `__call__` get an
      # additional named call carrying the method name.
      scope_name = module_name.rsplit("/", 1)[-1]
      call = jax.named_call(call, name=scope_name)
      if method_name != "__call__":
        call = jax.named_call(call, name=method_name)
    elif module_name:
      # TODO(lenamartens): remove this branch once jax_experimental_name_stack
      # flag is removed.
      cfg = config.get_config()
      if cfg.profiler_name_scopes and method_name == "__call__":
        scope_name = module_name.rsplit("/", 1)[-1]
        call = stateful.named_call(call, name=scope_name)

    result = call(*args, **kwargs)

    # Module names are set in the constructor. If `call` is the constructor
    # then its name will only be set **after** `call` has run. For methods
    # other than `__init__` we need the name before running in order to wrap
    # their execution with `named_call`.
    if module_name is None:
      module_name = getattr(self, "module_name", None)

    # Notify parent modules about our existence.
    if module_name is not None:
      for entry in frame.module_stack:
        if entry.module is not self:
          entry.module._submodules.add(module_name)  # pylint: disable=protected-access

  return result
def temporary_internal_state(state: InternalState):
  """Pushes a temporary copy of the internal state.

  Params or state that are `None` in the copy fall back to the values returned
  by `internal_state()`. A non-`None` rng is wrapped in a new
  `base.PRNGSequence`.
  """
  snapshot = copy_structure(state)

  rng_seq = snapshot.rng
  if rng_seq is not None:
    rng_seq = base.PRNGSequence(rng_seq)

  fallback = internal_state()
  new_params = fallback.params if snapshot.params is None else snapshot.params
  new_state = fallback.state if snapshot.state is None else snapshot.state

  evolved = base.current_frame().evolve(
      params=new_params, state=new_state, rng=rng_seq)
  return base.frame_stack(evolved)
def __call__(self, *args, **kwargs):
  """Returns the inner `(params, state)` pair, lifting it on first use.

  If the sentinel parameter is already present in this module's params bundle,
  params and state are unpacked from the frame. Otherwise the sentinel is
  created, the inner init function is run, and its params and state are packed
  into the frame under `self.module_name`.
  """
  frame = base.current_frame()
  bundle_name = self.module_name

  if _SENTINEL_NAME not in frame.params[bundle_name]:
    # Ensure sentinel is set for apply.
    base.get_parameter(_SENTINEL_NAME, (), init=jnp.zeros)
    # Lift parameters into this transform's params_dict.
    init_params, init_state = self._init_fn(*args, **kwargs)
    pack_into_dict(init_params, frame.params, bundle_name)
    pack_into_dict(init_state, frame.state, bundle_name)
    return init_params, init_state

  prefix = bundle_name + "/"
  return (
      unpack_from_dict(frame.params, prefix),
      unpack_from_dict(frame.state, prefix),
  )
def get_constant(name: str, value: Any, init=None, do_not_set=False):
  """Returns the constant stored under `name`, storing one if it is missing.

  Args:
    name: Key of the constant inside the current bundle's constants.
    value: Value to store when nothing is saved yet (passed through `init`
      first when `init` is given).
    init: Optional callable applied to `value` before it is stored.
    do_not_set: If true and nothing is saved yet, nothing is stored and `None`
      is returned.

  Returns:
    The previously saved value if there is one (a saved `None` counts as
    missing), otherwise the newly stored value, or `None` when `do_not_set`
    prevented storing.
  """
  bundle = current_frame().constants[current_bundle_name()]
  saved = bundle.get(name, None)

  if saved is not None:
    assert name in bundle, f"Missing {name} in constants"
    return saved

  if do_not_set:
    return None

  fresh = value if init is None else init(value)
  bundle[name] = fresh
  return fresh
def unique_and_canonical_name(name: str) -> str:
  """Returns a canonical name for the given name.

  The name is prefixed with the name of the module one level up the module
  stack (when there is one) and, depending on which of that module's methods
  is running, with a `~`-marked method component. A trailing `_<digits>`
  suffix is treated as an explicit index; otherwise a per-name counter picks
  the index, and index 0 is left unsuffixed.

  Args:
    name: The requested (unqualified) module name.

  Returns:
    The qualified name, which is also recorded as used.

  Raises:
    ValueError: If the qualified name has already been used.
  """
  frame = base.current_frame()

  # If we are outside init/call then prefix the name with the method name.
  if len(frame.module_stack) > 1:
    # -2 since we are inside the ctor and want to look at the caller state.
    module_state = frame.module_stack.peek(-2)

    # Make sure to include the method name if appropriate.
    method_name = module_state.method_name
    if method_name == "__init__":
      name = "~/" + name
    elif method_name != "__call__":
      name = "~" + method_name + "/" + name

    # Include the parent name.
    parent_module = module_state.module
    parent_name = base.safe_get_module_name(parent_module)
    name = parent_name + "/" + name

  # Test if the user has explicitly numbered this module. `maxsplit` is passed
  # by keyword: passing it positionally is deprecated since Python 3.13.
  splits = re.split(r"_(\d+)$", name, maxsplit=3)
  if len(splits) > 1:
    name, n = splits[0], int(splits[1])
    explicit_n = True
  else:
    n = None
    explicit_n = False

  # Determine a unique name for this module within the current context.
  counters = frame.counter_stack.peek(-2)
  if n is not None:
    # An explicit index pushes the counter past it.
    counters[name] = max(counters[name], n + 1)
  else:
    n = counters[name]
    counters[name] += 1
  # Index 0 is only spelled out when the user wrote it explicitly.
  qualified_name = f"{name}_{n}" if explicit_n or n else name

  # Final sanity check that this name has not been used before.
  used_names = frame.used_names_stack.peek(-2)
  if qualified_name in used_names:
    raise ValueError(f"Module name '{qualified_name}' is not unique.")
  used_names.add(qualified_name)

  return qualified_name
def params_dict(self) -> Mapping[base.ParamName, jnp.array]:
  """Returns parameters keyed by name for this module and submodules.

  Raises:
    ValueError: If there is no active frame (`base.frame_stack` is empty).
  """
  if not base.frame_stack:
    raise ValueError(
        "`module.params_dict()` must be used as part of an `hk.transform`.")

  own_name = self.module_name
  child_prefix = own_name + "/"
  # A bundle is included if it is ours, nested under our name, or belongs to
  # one of the recorded submodules. Keys are `<module name>/<param name>`.
  return {
      mod_name + "/" + param_name: param
      for mod_name, mod_params in base.current_frame().params.items()
      if (mod_name == own_name
          or mod_name.startswith(child_prefix)
          or mod_name in self._submodules)
      for param_name, param in mod_params.items()
  }
def params_or_state_dict(
    module_name: str,
    submodules: Set[str],
    which: str,
) -> Mapping[str, jnp.array]:
  """Returns module parameters or state for the given module or submodules.

  Args:
    module_name: Name of the module whose entries are collected.
    submodules: Names of further modules whose entries are included.
    which: Either "params" or "state"; selects the frame attribute to read.

  Returns:
    A dict keyed by `<module name>/<entry name>`. For "state" the `current`
    attribute of each stored value is returned; for "params" the value itself.
  """
  assert which in ("params", "state")

  is_state = which == "state"
  child_prefix = module_name + "/"
  bundles = getattr(base.current_frame(), which)

  collected = {}
  for their_module_name, bundle in bundles.items():
    selected = (
        their_module_name == module_name
        or their_module_name.startswith(child_prefix)
        or their_module_name in submodules
    )
    if not selected:
      continue
    for name, value in bundle.items():
      collected[their_module_name + "/" + name] = (
          value.current if is_state else value)
  return collected
def wrapped(module, *args, **kwargs):
  """Runs the wrapped method inside the module's scope.

  Raises:
    ValueError: If there is no active frame (`base.frame_stack` is empty).
  """
  if not base.frame_stack:
    raise ValueError(
        "All `hk.Module`s must be initialized inside an `hk.transform`."
    )

  frame = base.current_frame()
  call_state = base.ModuleState(module=module, method_name=method_name)
  with frame.module(call_state), _module_method_call(module, method_name):
    # hk.Module enters the module name scope for all methods.
    result = unbound_method(module, *args, **kwargs)

    # Notify parent modules about our existence. The name is read after the
    # call has run.
    module_name = getattr(module, "module_name", None)
    if module_name is not None:
      for entry in frame.module_stack:
        entry.module._submodules.add(module_name)  # pylint: disable=protected-access

  return result
def get_call_stack() -> Sequence[ModuleDetails]:
  """Returns a tuple of `ModuleDetails`, one per entry of the module stack.

  Each entry is built from the stack entry's module and method name, in the
  order the current frame's module stack yields them.
  """
  frame = base.current_frame()
  return tuple(
      ModuleDetails.of(entry.module, entry.method_name)
      for entry in frame.module_stack
  )
def wrapped(params, state, *args, **kwargs):
  """Calls `module` after loading `params` and `state` into the frame.

  Returns:
    A pair `(out, state)`: the module's output and a copy of the current
    frame's state taken after the call.
  """
  update_frame_data(params, state)
  result = module(*args, **kwargs)
  state_after = copy_structure(current_frame().state)
  return result, state_after
def get_bundled_state():
  """Returns `(state, constants, rng)` read from the current frame.

  The rng element is the RNG sequence's internal state, or `None` when the
  frame has no RNG sequence. State and constants are returned as stored.
  """
  frame = current_frame()
  rng_seq = frame.rng_stack.peek()
  rng_state = None if rng_seq is None else rng_seq.internal_state
  return (frame.state, frame.constants, rng_state)
def update_frame_data(params, state):
  """Writes `params` and `state` into the current frame.

  Params are only updated while `params_frozen()` is false; state is always
  updated.
  """
  frame = current_frame()
  params_writable = not params_frozen()
  if params_writable:
    update_recursive_skip_none(frame.params, params)
  update_recursive_skip_none(frame.state, state)