def wrapper(*args, **kwargs):
  """Calls ``fun`` after reserving, up front, the RNG keys it will request.

  ``fun`` is first traced with ``jax.eval_shape`` inside a fresh Haiku context
  built from the current frame's params, state and RNG, while
  ``count_hk_rngs_requested`` tallies the RNG keys it asks for. If the tally is
  non-zero, that many keys are reserved on the current frame's RNG sequence,
  and then ``fun`` is called for real with the original arguments.
  """
  base.assert_context("optimize_rng_use")

  # Snapshot everything the abstract trace needs from the current frame.
  frame = base.current_frame()
  frame_params = frame.params
  params = data_structures.to_haiku_dict(frame_params) if frame_params else None
  frame_state = frame.state
  state = base.extract_state(frame_state, initial=True) if frame_state else None
  rng_seq = frame.rng_stack.peek()
  rng = None if rng_seq is None else rng_seq.internal_state

  def pure_fun(params, state, rng, *args, **kwargs):
    # Rebuild the Haiku context explicitly so the traced function is pure.
    with base.new_context(params=params, state=state, rng=rng):
      return fun(*args, **kwargs)

  # eval_shape traces without computing values; used here only so the
  # counter can observe how many RNG keys ``fun`` requests.
  with count_hk_rngs_requested() as rng_count_f:
    jax.eval_shape(pure_fun, params, state, rng, *args, **kwargs)
  num_keys = rng_count_f()

  if num_keys:
    # NOTE(review): assumes an RNG sequence is on the stack whenever keys were
    # requested; a None from peek() here would raise AttributeError.
    base.current_frame().rng_stack.peek().reserve(num_keys)
  return fun(*args, **kwargs)
def merge(*structures: T) -> T:
  """Merges multiple input structures into one.

  >>> weights = {'linear': {'w': None}}
  >>> biases = {'linear': {'b': None}}
  >>> hk.data_structures.merge(weights, biases)
  {'linear': {'w': None, 'b': None}}

  If the same path appears in several structures, the value from the
  structure given last wins:

  >>> weights1 = {'linear': {'w': 1}}
  >>> weights2 = {'linear': {'w': 2}}
  >>> hk.data_structures.merge(weights1, weights2)
  {'linear': {'w': 2}}

  Note: returns a new structure not a view.

  Args:
    *structures: One or more structures to merge.

  Returns:
    A single structure with an entry for each path in the input structures.
  """
  merged = collections.defaultdict(dict)
  # Visit every (module, name, value) of every structure in argument order;
  # later assignments overwrite earlier ones for a repeated path.
  all_entries = (
      entry for structure in structures for entry in traverse(structure))
  for module_name, name, value in all_entries:
    merged[module_name][name] = value
  return data_structures.to_haiku_dict(merged)
def filter(  # pylint: disable=redefined-builtin
    predicate: Callable[[str, str, jnp.ndarray], bool],
    structure: T,
) -> T:
  """Filters an input structure according to a user specified predicate.

  >>> params = {'linear': {'w': None, 'b': None}}
  >>> predicate = lambda module_name, name, value: name == 'w'
  >>> hk.data_structures.filter(predicate, params)
  {'linear': {'w': None}}

  Note: returns a new structure not a view.

  Args:
    predicate: criterion used to select which entries are kept. The
      ``predicate`` argument is expected to be a boolean function taking as
      inputs the name of the module, the name of a given entry in the module
      data bundle (e.g. parameter name) and the corresponding data. Entries for
      which it returns a truthy value are kept; all others are dropped.
    structure: Haiku params or state data structure to be filtered.

  Returns:
    All the input parameters or state as selected by the input predicate.
  """
  out = collections.defaultdict(dict)
  for module_name, name, value in traverse(structure):
    if predicate(module_name, name, value):
      # Only modules with at least one kept entry get created in ``out``.
      out[module_name][name] = value
  return data_structures.to_haiku_dict(out)
def map(  # pylint: disable=redefined-builtin
    fn: Callable[[str, str, InT], OutT],
    structure: Mapping[str, Mapping[str, InT]],
) -> Mapping[str, Mapping[str, OutT]]:
  """Maps a function over every entry of an input structure.

  >>> params = {'linear': {'w': 1.0, 'b': 2.0}}
  >>> fn = lambda module_name, name, value: 2 * value if name == 'w' else value
  >>> hk.data_structures.map(fn, params)
  {'linear': {'b': 2.0, 'w': 2.0}}

  Note: returns a new structure not a view.

  Args:
    fn: function applied to each entry. It takes as inputs the name of the
      module, the name of a given entry in the module data bundle (e.g.
      parameter name) and the corresponding data, and returns the value to
      store for that entry in the output.
    structure: Haiku params or state data structure to be mapped.

  Returns:
    A structure with one entry per (module, name) entry visited in
    ``structure``, holding the value returned by ``fn`` for that entry.
  """
  out = collections.defaultdict(dict)
  for module_name, name, value in traverse(structure):
    out[module_name][name] = fn(module_name, name, value)
  return data_structures.to_haiku_dict(out)
def extract_state(state: MutableState, *, initial: bool) -> State:
  """Collapses mutable state entries into plain values.

  Each entry of ``state`` (module name -> entry name -> pair) must expose
  ``.initial`` and ``.current``; this picks one of the two for every entry.

  Args:
    state: Two-level mapping whose leaves have ``initial``/``current``.
    initial: If truthy take each pair's ``initial`` value, else ``current``.

  Returns:
    The selected values, converted with ``data_structures.to_haiku_dict``.
  """
  if initial:
    pick = lambda pair: pair.initial
  else:
    pick = lambda pair: pair.current

  snapshot = {}
  for module_name, bundle in state.items():
    snapshot[module_name] = {key: pick(pair) for key, pair in bundle.items()}
  return data_structures.to_haiku_dict(snapshot)
def test_copy(self, clone):
  """``clone`` must return an equal but distinct copy of an immutable dict."""
  before = data_structures.to_immutable_dict(dict(a=dict(b=1, c=2)))
  after = clone(before)
  self.assertIsNot(before, after)
  self.assertEqual(before, after)
  self.assertEqual(after, {"a": {"b": 1, "c": 2}})
  before_dict = data_structures.to_haiku_dict(before)
  # Leaf-by-leaf comparison. ``jax.tree_map`` is a deprecated alias (removed
  # in newer JAX); ``jax.tree_util.tree_map`` is the supported spelling.
  jax.tree_util.tree_map(self.assertEqual, before_dict, after)
def new_context(
    *,
    params: Optional[Params] = None,
    state: Optional[State] = None,
    rng: Optional[Union[PRNGKey, int, PRNGSequence]] = None,
) -> HaikuContext:
  """Collects the results of hk.{get,set}_{parameter,state} calls.

  >>> with new_context(rng=jax.random.PRNGKey(42)) as ctx:
  ...   mod = hk.nets.MLP([300, 100, 10])
  ...   y1 = mod(jnp.ones([1, 1]))

  >>> assert len(jax.tree_util.tree_leaves(ctx.collect_params())) == 6

  >>> with ctx:
  ...   y2 = mod(jnp.ones([1, 1]))

  The same module instance in the same context will produce the same value:

  >>> assert (y1 == y2).all()

  Args:
    params: Optional parameter values to inject. When given, they are
      converted with ``data_structures.to_haiku_dict`` and the context is
      created with ``freeze_params=True``.
    state: Optional state values to inject. Each value ``v`` is stored as
      ``StatePair(v, v)``.
    rng: Optional rng to inject. Anything that is not already a
      ``PRNGSequence`` is wrapped in one; a ``PRNGSequence`` is used as-is.

  Returns:
    Context manager which closes over mutable Haiku internal state.
  """
  if params is None:
    params = collections.defaultdict(dict)
    freeze_params = False
  else:
    params = data_structures.to_haiku_dict(params)
    freeze_params = True

  if state is None:
    state = collections.defaultdict(dict)
  else:
    # Injected values seed both halves of each pair.
    state = {
        m: {k: StatePair(v, v) for k, v in p.items()}
        for m, p in state.items()
    }
    state = collections.defaultdict(dict, state)

  if rng is not None and not isinstance(rng, PRNGSequence):
    rng = PRNGSequence(rng)

  return HaikuContext(params, state, rng, freeze_params)
def merge(
    *structures: Mapping[str, Mapping[str, Any]],
    check_duplicates: bool = False,
) -> Mapping[str, Mapping[str, Any]]:
  """Merges multiple input structures.

  >>> weights = {'linear': {'w': None}}
  >>> biases = {'linear': {'b': None}}
  >>> hk.data_structures.merge(weights, biases)
  {'linear': {'w': None, 'b': None}}

  When structures are not disjoint the output will contain the value from the
  last structure for each path:

  >>> weights1 = {'linear': {'w': 1}}
  >>> weights2 = {'linear': {'w': 2}}
  >>> hk.data_structures.merge(weights1, weights2)
  {'linear': {'w': 2}}

  Note: returns a new structure not a view.

  Args:
    *structures: One or more structures to merge.
    check_duplicates: If True, a ValueError is raised when the same path
      appears in more than one structure and the values there differ in shape
      or dtype. Values lacking ``shape``/``dtype`` attributes are treated as
      shapeless: two such values never conflict, but one conflicts with an
      array.

  Returns:
    A single structure with an entry for each path in the input structures.

  Raises:
    ValueError: If ``check_duplicates`` is True and a repeated path holds
      values whose shape or dtype differ.
  """

  def array_like(o):
    return hasattr(o, "shape") and hasattr(o, "dtype")

  def shaped(a):
    # None for non-arrays, so two non-arrays always compare equal.
    return (a.shape, a.dtype) if array_like(a) else None

  def fmt(a):
    return utils.format_array(a) if array_like(a) else repr(a)

  out = collections.defaultdict(dict)
  for structure in structures:
    for module_name, name, value in traverse(structure):
      if check_duplicates and name in out[module_name]:
        previous = out[module_name][name]
        if shaped(previous) != shaped(value):
          raise ValueError(
              "Duplicate array found with different shape/dtype for "
              f"{module_name}.{name}: {fmt(previous)} vs {fmt(value)}.")
      out[module_name][name] = value
  return data_structures.to_haiku_dict(out)
def partition_n(
    fn: Callable[[str, str, jnp.ndarray], int],
    structure: T,
    n: int,
) -> Tuple[T, ...]:
  """Partitions a structure into `n` structures.

  For a given set of parameters, you can use :func:`partition_n` to split them
  into ``n`` groups. For example, to split your parameters/gradients by module
  name:

  >>> def partition_by_module(structure):
  ...   cnt = itertools.count()
  ...   d = collections.defaultdict(lambda: next(cnt))
  ...   fn = lambda m, n, v: d[m]
  ...   return hk.data_structures.partition_n(fn, structure, len(structure))

  >>> structure = {f'layer_{i}': {'w': None, 'b': None} for i in range(3)}
  >>> for substructure in partition_by_module(structure):
  ...   print(substructure)
  {'layer_0': {'b': None, 'w': None}}
  {'layer_1': {'b': None, 'w': None}}
  {'layer_2': {'b': None, 'w': None}}

  Args:
    fn: Callable returning the bucket in ``[0, n)`` into which the given
      element should be placed.
    structure: Haiku params or state data structure to be partitioned.
    n: The total number of buckets.

  Returns:
    A tuple of size ``n``, where each element will contain the values for which
    the function returned the current index.

  Raises:
    AssertionError: If ``fn`` returns a bucket index outside ``[0, n)``.
  """
  out = [collections.defaultdict(dict) for _ in range(n)]
  for module_name, name, value in traverse(structure):
    i = fn(module_name, name, value)
    # Explicit raise rather than ``assert`` so the check also runs under
    # ``python -O`` (otherwise a negative index would silently pick a bucket
    # from the end). AssertionError is kept so existing callers still match.
    if not (0 <= i < n):
      raise AssertionError(f"{i} must be in range [0, {n})")
    out[i][module_name][name] = value
  return tuple(data_structures.to_haiku_dict(o) for o in out)
def apply_fn(params, state, *args, **kwargs):
  """Applies ``f`` with ``params`` only and pairs the output with empty state.

  The incoming ``state`` is accepted for signature compatibility and ignored.
  """
  del state  # ``f.apply`` is called without state.
  result = f.apply(params, *args, **kwargs)
  return result, data_structures.to_haiku_dict({})
def init_fn(*args, **kwargs):
  """Initialises ``f`` and returns its params alongside an empty state."""
  return f.init(*args, **kwargs), data_structures.to_haiku_dict({})
def collect_params(self) -> Params:
  """Returns this context's params, converted via ``to_haiku_dict``."""
  # ``__params`` is name-mangled to the enclosing class when defined inside it.
  collected = self.__params
  return data_structures.to_haiku_dict(collected)