def fast_eval_shape(fun, *args, **kwargs):
  """Computes the output shape of ``fun`` cheaply, like ``jax.eval_shape``.

  Behaves like :func:`jax.eval_shape`, but while ``fun`` is traced some calls
  whose output shapes are trivially known are stubbed out. This trims Python
  overhead that can accumulate when tracing very large models.

  Substitutions in effect while ``fun`` runs:

  * every parameter/state creator is replaced by ``zeros_creator``;
  * ``basic.dropout_impl`` is replaced by ``noop_dropout`` (``hk.dropout``
    behaves as the identity);
  * ``jax.random.fold_in`` returns its ``key`` argument unchanged.

  Args:
    fun: The function to trace.
    *args: Positional arguments forwarded to ``fun``.
    **kwargs: Keyword arguments forwarded to ``fun``.

  Returns:
    The result of ``stateful.eval_shape`` when called inside a Haiku
    transform, otherwise the result of ``jax.eval_shape``.
  """
  identity_fold_in = lambda key, data: key
  with base.custom_creator_unsafe(zeros_creator), \
       mock.patch.object(basic, 'dropout_impl', noop_dropout), \
       mock.patch.object(jax.random, 'fold_in', identity_fold_in):
    # Inside a transform the Haiku-aware variant is required; outside,
    # plain JAX is sufficient.
    shape_fn = (stateful.eval_shape if base.inside_transform()
                else jax.eval_shape)
    return shape_fn(fun, *args, **kwargs)
def cond(*args, **kwargs):
  """Equivalent to :func:`jax.lax.cond` but with Haiku state passed in/out.

  Two calling conventions are accepted. The legacy five-argument form,
  ``cond(pred, true_operand, true_fun, false_operand, false_fun)`` (bound
  against ``_old_cond``), is tried first. If that binding fails, the
  four-argument form ``cond(pred, true_fun, false_fun, operand)`` (bound
  against ``_new_cond``) is used.

  Args:
    *args: Positional arguments in either convention above.
    **kwargs: Keyword arguments in either convention above.

  Returns:
    The output of whichever branch ``jax.lax.cond`` selects.

  Raises:
    ValueError: If called outside of ``hk.transform``.
    TypeError: If the arguments match neither convention.
  """
  if not base.inside_transform():
    raise ValueError(
        "hk.cond() should not be used outside of hk.transform(). "
        "Use jax.cond() instead.")

  try:
    legacy = inspect.signature(_old_cond).bind(*args, **kwargs)
  except TypeError:
    # Not the legacy signature: fall back to the single-operand form.
    modern = inspect.signature(_new_cond).bind(*args, **kwargs)
    pred, on_true, on_false, operand = modern.args
  else:
    pred, true_operand, raw_true, false_operand, raw_false = legacy.args
    # Pack both operands into one; each branch picks out its own half.
    operand = (true_operand, false_operand)
    on_true = lambda op, f=raw_true: f(op[0])
    on_false = lambda op, f=raw_false: f(op[1])

  # Presumably keeps the number of RNG subkeys carried through both branches
  # identical, so their state outputs have matching structure.
  reserve_up_to_full_rng_block()
  # Memoizing by id maps an identical true/false callable to one wrapper.
  wrap_branch = _memoize_by_id(stateful_branch)
  state_before = internal_state()
  result, state_after = jax.lax.cond(
      pred,
      true_fun=wrap_branch(on_true),
      false_fun=wrap_branch(on_false),
      operand=(state_before, operand))
  update_internal_state(state_after)
  return result
def fori_loop(lower, upper, body_fun, init_val):
  """Equivalent to ``jax.lax.fori_loop`` with Haiku state threaded in/out.

  The current Haiku state is carried through the loop alongside the user's
  value. During ``init`` (params not frozen), the first iteration runs
  eagerly in Python so the state structure is fixed before entering
  ``jax.lax.fori_loop``.

  Args:
    lower: First loop index. May be a Python int or a traced value.
    upper: Loop bound forwarded to ``jax.lax.fori_loop``. May be traced.
    body_fun: ``body_fun(i, val) -> val``. May read and write Haiku state.
    init_val: Initial loop-carried value.

  Returns:
    The loop-carried value after the final iteration.

  Raises:
    ValueError: If called outside of ``hk.transform``.
  """
  if not base.inside_transform():
    raise ValueError(
        "hk.fori_loop() should not be used outside of hk.transform(). "
        "Use jax.lax.fori_loop() instead.")

  def pure_body_fun(i, val):
    state, val = val
    with temporary_internal_state(state):
      val = body_fun(i, val)
      state = internal_state()
      return state, val

  if not base.params_frozen():
    # During init we need to unwind one step of the loop to ensure the Haiku
    # state before and after the body has the same structure.
    init_val = body_fun(lower, init_val)
    lower += 1
    try:
      if upper - lower == 0:
        return init_val
    except jax.errors.ConcretizationTypeError:
      # upper or lower may be tracers (e.g. under jit). The early-exit test
      # cannot be decided in Python then, but jax.lax.fori_loop handles it.
      pass

  state = internal_state()
  init_val = state, init_val
  state, val = jax.lax.fori_loop(lower, upper, pure_body_fun, init_val)
  update_internal_state(state)
  return val
def fori_loop(lower, upper, body_fun, init_val):
  """Haiku-aware :func:`jax.lax.fori_loop` that carries Haiku state.

  Haiku state is threaded through the loop carry together with the user's
  value. During ``init`` (params not frozen), the first iteration is executed
  eagerly in Python so the state has the same structure before and after the
  loop body.

  Args:
    lower: First loop index. May be a tracer.
    upper: Loop bound forwarded to ``jax.lax.fori_loop``. May be a tracer.
    body_fun: ``body_fun(i, val) -> val``. May use Haiku state.
    init_val: Initial loop-carried value.

  Returns:
    The loop-carried value after the last iteration.

  Raises:
    ValueError: If called outside of ``hk.transform``.
  """
  if not base.inside_transform():
    raise ValueError(
        "hk.fori_loop() should not be used outside of hk.transform(). "
        "Use jax.lax.fori_loop() instead.")

  def body_with_state(i, carry):
    state_in, loop_val = carry
    with temporary_internal_state(state_in):
      loop_val = body_fun(i, loop_val)
      # Top the RNG block back up so every iteration emits a carry with the
      # same structure as it received.
      reserve_up_to_full_rng_block()
      return internal_state(), loop_val

  if not base.params_frozen():
    # Init: run the first step eagerly so all state/params exist before the
    # traced loop, then skip that step in the loop itself.
    init_val = body_fun(lower, init_val)
    lower += 1
    try:
      loop_is_empty = bool(upper - lower == 0)
    except jax.errors.ConcretizationTypeError:
      # Traced bounds: emptiness is unknown here; jax.lax.fori_loop copes.
      loop_is_empty = False
    if loop_is_empty:
      return init_val

  reserve_up_to_full_rng_block()
  carry = (internal_state(), init_val)
  final_state, final_val = jax.lax.fori_loop(lower, upper, body_with_state,
                                             carry)
  update_internal_state(final_state)
  return final_val
def wrapped_dec_fun(fun, *dec_args, **dec_kwargs):
  """Applies ``dec_fun`` to a version of ``fun`` that passes Haiku state.

  ``dec_fun`` is a free variable from an enclosing scope not visible here.
  The error message below suggests it is a ``jax.*`` function. The decorated
  function receives the current Haiku state through the ``hk_state`` keyword
  argument and returns its output together with the state difference. That
  difference is then applied with ``update_internal_state``.

  Args:
    fun: Function to decorate. May read or write Haiku state.
    *dec_args: Extra positional arguments forwarded to ``dec_fun``.
    **dec_kwargs: Extra keyword arguments forwarded to ``dec_fun``.

  Returns:
    A callable with the same call signature as ``fun``.

  Raises:
    ValueError: If called outside of ``hk.transform``.
  """
  if not base.inside_transform():
    raise ValueError(
        "hk.{0}() should not be used outside of hk.transform. "
        "Use jax.{0}() instead.".format(dec_fun.__name__))

  @functools.wraps(fun)
  def fun_with_explicit_state(*args, **kwargs):
    incoming = kwargs.pop("hk_state")
    with temporary_internal_state(incoming, share_python_state=True):
      result = fun(*args, **kwargs)
      return result, difference(incoming, internal_state())

  transformed = dec_fun(fun_with_explicit_state, *dec_args, **dec_kwargs)

  @functools.wraps(transformed)
  def call_with_state(*args, **kwargs):
    kwargs["hk_state"] = internal_state()
    result, new_state = transformed(*args, **kwargs)
    update_internal_state(new_state)
    return result

  return call_with_state
def wrapper(*args, **kwargs):
  """Calls ``fun`` through ``jax.named_call`` under the given ``name``.

  ``fun`` and ``name`` are free variables from an enclosing scope not visible
  here. Inside a Haiku transform, ``jax.named_call`` is first wrapped by
  ``thread_hk_state_in_kwargs`` so Haiku state is passed through explicitly.
  """
  named_call = (thread_hk_state_in_kwargs(jax.named_call)
                if base.inside_transform() else jax.named_call)
  return named_call(fun, name=name)(*args, **kwargs)
def scan(f, init, xs, length=None, reverse=False):
  """Equivalent to `jax.lax.scan` but with Haiku state threaded in and out.

  Args:
    f: ``f(carry, x) -> (carry, y)``. May use Haiku params and state.
    init: Initial carry.
    xs: Pytree of arrays scanned along their leading axis.
    length: Number of iterations. When ``None`` it is taken from the leading
      dimension of the first leaf of ``xs``.
    reverse: If True, iterate from the end of ``xs`` towards the start.

  Returns:
    ``(final_carry, ys)``, where ``ys`` stacks the per-step outputs along the
    leading axis in the same index order as ``xs``.

  Raises:
    ValueError: If called outside of ``hk.transform``, or if ``length`` is
      None and ``xs`` has no leaves.
  """
  if not base.inside_transform():
    raise ValueError(
        "hk.scan() should not be used outside of hk.transform(). "
        "Use jax.scan() instead.")

  if length is None:
    leaves = jax.tree_leaves(xs)
    if not leaves:
      # Fail clearly (as jax.lax.scan does) rather than with an IndexError.
      raise ValueError(
          "hk.scan() got no values to scan over and `length` was not "
          "provided.")
    length = leaves[0].shape[0]

  running_init_fn = not base.params_frozen()

  if running_init_fn:
    # During `init` we need to unroll one step of the scan, this is because our
    # carry contains the Haiku state and during `init` this may change structure
    # (e.g. as state is created).
    if not length:
      x0 = jax.tree_map(lambda x: jnp.zeros(x.shape[1:], x.dtype), xs)
      _, y0 = f(init, x0)
      y0 = jax.tree_map(lambda y: jnp.zeros((0,) + y.shape, y.dtype), y0)
      return init, y0

    if reverse:
      x0 = jax.tree_map(lambda x: x[-1], xs)
      xs = jax.tree_map(lambda x: x[:-1], xs)
    else:
      x0 = jax.tree_map(lambda x: x[0], xs)
      xs = jax.tree_map(lambda x: x[1:], xs)
    init, y0 = f(init, x0)
    y0 = jax.tree_map(lambda y: jnp.expand_dims(y, 0), y0)
    length -= 1
    if not length:
      return init, y0

  def stateful_fun(carry, x):
    carry, state = carry
    with temporary_internal_state(state):
      with base.assert_no_new_parameters():
        carry, out = f(carry, x)
      carry = (carry, internal_state(params=False))
      return carry, out

  # We know that we don't need to thread params in and out, since for init we
  # have already created them (given that above we unroll one step of the scan)
  # and for apply we know they are immutable. As such we only need to thread the
  # state and rng in and out.
  init = (init, internal_state(params=False))
  (carry, state), ys = jax.lax.scan(stateful_fun, init, xs, length, reverse)
  update_internal_state(state)

  if running_init_fn:
    # y0 came from the step executed eagerly above: index 0 when scanning
    # forwards, the last index when scanning in reverse. Put it on the
    # matching side so ys stays aligned with xs.
    if reverse:
      ys = jax.tree_map(lambda y0, ys: jnp.concatenate([ys, y0]), y0, ys)
    else:
      ys = jax.tree_map(lambda y0, ys: jnp.concatenate([y0, ys]), y0, ys)

  return carry, ys
def wrapper(*args, **kwargs):
  """Calls ``fun`` through ``_named_call`` under the given ``name``.

  ``fun`` and ``name`` are free variables from an enclosing scope not visible
  here. Inside a Haiku transform, ``fun`` is wrapped with ``statefulify`` so
  Haiku state is passed in and returned explicitly. The returned state is
  then written back with ``stateful.update_internal_state``.
  """
  if not base.inside_transform():
    return _named_call(fun, name=name)(*args, **kwargs)

  # fun may be stateful: thread state explicitly to keep the named call
  # functionally pure.
  state_in = stateful.internal_state()
  named_fn = _named_call(statefulify(fun, state_in), name=name)
  result, state_out = named_fn(*args, **kwargs)
  stateful.update_internal_state(state_out)
  return result
def eval_shape(fun, *args, **kwargs):
  """Haiku-aware :func:`jax.eval_shape` that discards any state changes.

  Args:
    fun: Function to evaluate abstractly.
    *args: Positional arguments forwarded to ``fun``.
    **kwargs: Keyword arguments forwarded to ``fun``.

  Returns:
    The result of ``jax.eval_shape(fun, *args, **kwargs)``.

  Raises:
    ValueError: If called outside of ``hk.transform``.
  """
  if base.inside_transform():
    snapshot = internal_state()
    # Evaluate inside a temporary state scope seeded with the current state.
    # Whatever ``fun`` changes there is deliberately not propagated back.
    with temporary_internal_state(snapshot):
      return jax.eval_shape(fun, *args, **kwargs)
  raise ValueError(
      "hk.eval_shape() should not be used outside of hk.transform(). "
      "Use jax.eval_shape() instead.")
def switch(index, branches, operand):
  """Equivalent to :func:`jax.lax.switch` but with Haiku state passed in/out.

  Args:
    index: Integer scalar selecting the branch, as in ``jax.lax.switch``.
    branches: Iterable of callables. Each is wrapped with ``stateful_branch``.
    operand: Value handed to the selected branch. ``jax.lax.switch`` receives
      the pair ``(state, operand)``.

  Returns:
    The output of the selected branch.

  Raises:
    ValueError: If called outside of ``hk.transform``.
  """
  if not base.inside_transform():
    raise ValueError(
        "hk.switch() should not be used outside of hk.transform(). "
        "Use jax.switch() instead.")

  reserve_up_to_full_rng_block()
  # Wrap branches through an id-memoized stateful_branch, as hk.cond does, so
  # a callable appearing several times in ``branches`` maps to one wrapper
  # (presumably letting JAX reuse its trace instead of re-tracing each copy).
  stateful_branch_mem = _memoize_by_id(stateful_branch)
  state = internal_state()
  out, state = jax.lax.switch(index,
                              tuple(map(stateful_branch_mem, branches)),
                              (state, operand))
  update_internal_state(state)
  return out
def cond(pred, true_operand, true_fun, false_operand, false_fun):
  """Haiku-aware ``jax.lax.cond`` (legacy five-argument form).

  Each branch receives the current Haiku state paired with its own operand.
  The state returned by the selected branch is written back with
  ``update_internal_state``.

  Args:
    pred: Predicate selecting the branch.
    true_operand: Operand for ``true_fun``.
    true_fun: Branch taken when ``pred`` is true.
    false_operand: Operand for ``false_fun``.
    false_fun: Branch taken when ``pred`` is false.

  Returns:
    The output of the selected branch.

  Raises:
    ValueError: If called outside of ``hk.transform``.
  """
  if not base.inside_transform():
    raise ValueError("hk.cond() should not be used outside of hk.transform(). "
                     "Use jax.cond() instead.")
  state_before = internal_state()
  branch_kwargs = dict(
      true_operand=(state_before, true_operand),
      true_fun=stateful_branch(true_fun),
      false_operand=(state_before, false_operand),
      false_fun=stateful_branch(false_fun))
  result, state_after = jax.lax.cond(pred, **branch_kwargs)
  update_internal_state(state_after)
  return result
def wrapper(self, *a, **k):
  """Guards a state-updater method (``f``) so it runs at most once.

  ``f`` is a free variable from an enclosing scope not visible here. The call
  is rejected if the updater was already used, if it is called outside a
  transform, or if the current context's id differs from
  ``self._context_id``.

  Raises:
    ValueError: If any of the checks above fails. Checks run in that order.
  """
  # pylint: disable=protected-access
  # Lambdas keep each check lazy, so later checks (e.g. current_context())
  # only run once the earlier ones have passed.
  guards = (
      (lambda: self._used,
       "State updater must only be used once."),
      (lambda: not base.inside_transform(),
       "State updater must be used inside hk.transform_with_state."),
      (lambda: self._context_id != id(base.current_context()),
       "State updater must be used within the same call to init/apply."),
  )
  for violated, message in guards:
    if violated():
      raise ValueError(message)
  self._used = True
  # pylint: enable=protected-access
  return f(self, *a, **k)
def eval_shape(fun, *args, **kwargs):
  """Equivalent to jax.eval_shape with any changed Haiku state discarded.

  The current Haiku state is passed to ``jax.eval_shape`` as an explicit
  argument rather than closed over, and ``fun`` runs under it via
  ``temporary_internal_state``. State modified by ``fun`` is not returned,
  so it is not written back.

  Args:
    fun: Function to evaluate abstractly.
    *args: Positional arguments forwarded to ``fun``.
    **kwargs: Keyword arguments forwarded to ``fun``. Any name is allowed,
      including ``state``.

  Returns:
    The result of ``jax.eval_shape`` for ``fun(*args, **kwargs)``.

  Raises:
    ValueError: If called outside of ``hk.transform``.
  """
  if not base.inside_transform():
    raise ValueError(
        "hk.eval_shape() should not be used outside of hk.transform(). "
        "Use jax.eval_shape() instead.")

  @functools.wraps(fun)
  def stateless_fun(*state_and_args, **kwargs):
    # The Haiku state travels as the first positional argument. It is not a
    # named parameter, so a user keyword argument called ``state`` cannot
    # collide with it (previously: "got multiple values for 'state'").
    state, *args = state_and_args
    with temporary_internal_state(state):
      out = fun(*args, **kwargs)
      # Don't return changed state
      return out

  out_shape = jax.eval_shape(stateless_fun, internal_state(), *args, **kwargs)
  return out_shape
def wrapper(*args, **kwargs):
  """Calls ``fun`` through ``jax.named_call`` while allowing non-JAX outputs.

  ``fun`` and ``name`` are free variables from an enclosing scope not visible
  here. ``hide_non_jaxtype_outputs`` records non-JAX outputs and the output
  treedef in ``side_channel``. The full output pytree is rebuilt afterwards
  from those records.
  """
  side_channel = {"non_jaxtypes": [], "treedef": None}
  hidden = hide_non_jaxtype_outputs(fun, side_channel)
  named_call = (thread_hk_state_in_kwargs(jax.named_call)
                if base.inside_transform() else jax.named_call)
  jax_leaves = named_call(hidden, name=name)(*args, **kwargs)
  # Where the JAX-side leaf is None, use the entry stashed in the side
  # channel instead.
  merged = [stashed if leaf is None else leaf
            for leaf, stashed in zip(jax_leaves, side_channel["non_jaxtypes"])]
  return jax.tree_unflatten(side_channel["treedef"], merged)
def cond(*args, **kwargs):
  """Equivalent to :func:`jax.lax.cond` but with Haiku state passed in/out.

  Two calling conventions are accepted:

  * legacy: ``cond(pred, true_operand, true_fun, false_operand, false_fun)``
    (bound against ``_old_cond``);
  * current: ``cond(pred, true_fun, false_fun, *operands)`` or
    ``cond(pred, true_fun, false_fun, operand=...)`` (bound against
    ``_new_cond``).

  The legacy form is tried first. It is rejected when either function slot
  is not callable, which disambiguates calls such as
  ``cond(pred, tf, ff, 1, 2)``.

  Returns:
    The output of whichever branch ``jax.lax.cond`` selects.

  Raises:
    ValueError: If called outside of ``hk.transform``, or if ``operand=`` is
      combined with positional operands.
    TypeError: If the arguments match neither convention.
  """
  if not base.inside_transform():
    raise ValueError(
        "hk.cond() should not be used outside of hk.transform(). "
        "Use jax.cond() instead.")

  try:
    legacy = inspect.signature(_old_cond).bind(*args, **kwargs)
    pred, true_operand, true_fun, false_operand, false_fun = legacy.args
    if not (callable(true_fun) and callable(false_fun)):
      # Two operand new cond case: cond(pred, tf, ff, 1, 2)
      raise TypeError
  except TypeError:
    modern = inspect.signature(_new_cond).bind(*args, **kwargs)
    modern.apply_defaults()
    pred, true_fun, false_fun, *operands = modern.args
    kw_operand = modern.kwargs["operand"]
    if kw_operand is not SENTINEL:
      if operands:
        raise ValueError(
            "When the operand keyword argument is used you cannot "  # pylint: disable=raise-missing-from
            "also pass operands positionally. Got "
            f"operand={kw_operand} and *operands={tuple(operands)}")
      operands = (kw_operand,)
  else:
    # Legacy form: both operands travel together; each branch takes its half.
    operands = ((true_operand, false_operand),)
    true_fun = lambda op, f=true_fun: f(op[0])
    false_fun = lambda op, f=false_fun: f(op[1])

  reserve_up_to_full_rng_block()
  # Memoizing by id maps an identical true/false callable to one wrapper.
  wrap_branch = _memoize_by_id(stateful_branch)
  state_in = internal_state()
  result, state_out = jax.lax.cond(
      pred,
      true_fun=wrap_branch(true_fun),
      false_fun=wrap_branch(false_fun),
      operand=(state_in, operands))
  update_internal_state(state_out)
  return result
def scan(f, init, xs, length=None, reverse=False, unroll=1):
  """Equivalent to :func:`jax.lax.scan` but with Haiku state passed in/out.

  Args:
    f: ``f(carry, x) -> (carry, y)``. May use Haiku params and state.
    init: Initial carry.
    xs: Pytree of arrays scanned along their leading axis.
    length: Number of iterations. When ``None`` it is taken from the leading
      dimension of the first leaf of ``xs``.
    reverse: If True, iterate from the end of ``xs`` towards the start.
    unroll: Forwarded to ``jax.lax.scan``'s ``unroll`` argument.

  Returns:
    ``(final_carry, ys)``, where ``ys`` stacks the per-step outputs along the
    leading axis in the same index order as ``xs``.

  Raises:
    ValueError: If called outside of ``hk.transform``, or if ``length`` is
      None and ``xs`` has no leaves.
  """
  if not base.inside_transform():
    raise ValueError(
        "hk.scan() should not be used outside of hk.transform(). "
        "Use jax.scan() instead.")

  if length is None:
    leaves = jax.tree_leaves(xs)
    if not leaves:
      # Fail clearly (as jax.lax.scan does) rather than with an IndexError.
      raise ValueError(
          "hk.scan() got no values to scan over and `length` was not "
          "provided.")
    length = leaves[0].shape[0]

  running_init_fn = not base.params_frozen()

  if running_init_fn:
    # During `init` we need to unroll one step of the scan, this is because our
    # carry contains the Haiku state and during `init` this may change structure
    # (e.g. as state is created).
    if not length:
      x0 = jax.tree_map(lambda x: jnp.zeros(x.shape[1:], x.dtype), xs)
      _, y0 = f(init, x0)
      y0 = jax.tree_map(lambda y: jnp.zeros((0,) + y.shape, y.dtype), y0)
      return init, y0

    if reverse:
      x0 = jax.tree_map(lambda x: x[-1], xs)
      xs = jax.tree_map(lambda x: x[:-1], xs)
    else:
      x0 = jax.tree_map(lambda x: x[0], xs)
      xs = jax.tree_map(lambda x: x[1:], xs)
    init, y0 = f(init, x0)
    y0 = jax.tree_map(lambda y: jnp.expand_dims(y, 0), y0)
    length -= 1
    if not length:
      return init, y0

  @functools.wraps(f)
  def stateful_fun(carry, x):
    carry, state = carry
    with temporary_internal_state(state):
      with base.assert_no_new_parameters(), \
           base.push_jax_trace_level():
        carry, out = f(carry, x)
      reserve_up_to_full_rng_block()
      carry = (carry, internal_state(params=False))
      return carry, out

  # Before pulling out the internal state, reserve a full block of RNG keys.
  # This is to make sure we're always passing in the same amount of subkeys in
  # and out of the scan carry (scan requires equal length lists).
  # After every scan iteration we reserve back up to the full block.
  reserve_up_to_full_rng_block()

  # We know that we don't need to thread params in and out, since for init we
  # have already created them (given that above we unroll one step of the scan)
  # and for apply we know they are immutable. As such we only need to thread the
  # state and rng in and out.
  init = (init, internal_state(params=False))
  (carry, state), ys = jax.lax.scan(stateful_fun, init, xs, length, reverse,
                                    unroll=unroll)
  update_internal_state(state)

  if running_init_fn:
    # y0 came from the eagerly executed step: the first element when scanning
    # forwards, the last element when scanning in reverse.
    if reverse:
      ys = jax.tree_map(lambda y0, ys: jnp.concatenate([ys, y0]), y0, ys)
    else:
      ys = jax.tree_map(lambda y0, ys: jnp.concatenate([y0, ys]), y0, ys)

  return carry, ys
def value_and_grad(fun, argnums=0, has_aux=False, holomorphic=False):
  r"""Returns a function computing both ``fun`` and its gradient, Haiku-aware.

  NOTE: You only need this when you take a gradient **inside** a
  :func:`transform`\ ed function *and* the function being differentiated
  uses :func:`set_state`. Haiku state is routed through
  :func:`jax.value_and_grad` as auxiliary output. The state difference
  produced by ``fun`` is then applied with ``update_internal_state``. For
  example:

  >>> class MyModule(hk.Module):
  ...   def __call__(self, x):
  ...     hk.set_state("last", jnp.sum(x))
  ...     return x ** 2

  >>> def f(x):
  ...   m = MyModule()
  ...   y, g = hk.value_and_grad(m)(x)
  ...   return y, g

  >>> f = hk.transform_with_state(f)
  >>> x = jnp.array(2.)
  >>> _ = jax.jit(f.init)(None, x)

  Args:
    fun: Function to be differentiated. Its arguments at positions specified
      by ``argnums`` should be arrays, scalars, or standard Python containers.
      It should return a scalar (which includes arrays with shape ``()`` but
      not arrays with shape ``(1,)`` etc.)
    argnums: Optional, integer or tuple of integers. Specifies which
      positional argument(s) to differentiate with respect to (default 0).
    has_aux: Optional, bool. Indicates whether ``fun`` returns a pair where
      the first element is the output of the mathematical function to be
      differentiated and the second element is auxiliary data. Default False.
    holomorphic: Optional, bool. Indicates whether ``fun`` is promised to be
      holomorphic. Default False.

  Returns:
    A function with the same arguments as ``fun`` that returns
    ``(value, grads)``, or ``((value, aux), grads)`` when ``has_aux`` is set.
    If ``argnums`` is an integer, the gradient has the same shape and type as
    that positional argument. If it is a tuple of integers, the gradient is a
    tuple of values matching the corresponding arguments.

  Raises:
    ValueError: If called outside of ``hk.transform``.
  """
  if not base.inside_transform():
    raise ValueError("hk.grad() should not be used outside of hk.transform(). "
                     "Use jax.grad() instead.")

  @functools.wraps(fun)
  def stateful_fun(*args, **kwargs):
    state_in = kwargs.pop("hk_state")
    with temporary_internal_state(state_in):
      result = fun(*args, **kwargs)
      if has_aux:
        value, aux = result
      else:
        value, aux = result, None
      state_out = difference(state_in, internal_state())
      # Haiku state rides along as JAX aux output so it is not differentiated.
      return value, (aux, state_out)

  grad_fun = jax.value_and_grad(stateful_fun, argnums=argnums, has_aux=True,
                                holomorphic=holomorphic)

  @functools.wraps(grad_fun)
  def wrapper(*args, **kwargs):
    kwargs["hk_state"] = internal_state()
    (value, (aux, hk_state)), grads = grad_fun(*args, **kwargs)
    update_internal_state(hk_state)
    primal = (value, aux) if has_aux else value
    return primal, grads

  return wrapper