Example #1
def fast_eval_shape(fun, *args, **kwargs):
    """Equivalent to ``eval_shape`` in JAX.

  This utility is equivalent to ``eval_shape`` in JAX except that it avoids
  running Haiku functions whose shapes are trivially known. This can avoid some
  Python overheads in JAX which can accumulate for very large models.

  Optimizations:

  * All parameter/state initialisers replaced with zeros.
  * ``hk.dropout`` replaced with identity.
  * ``jax.random.fold_in`` replaced with identity.

  Args:
    fun: The function to trace.
    *args: Positional arguments to ``fun``.
    **kwargs: Keyword arguments to ``fun``.

  Returns:
    The shape produced by ``fun`` for the given args/kwargs.
  """
    with base.custom_creator_unsafe(zeros_creator), \
         mock.patch.object(basic, 'dropout_impl', noop_dropout), \
         mock.patch.object(jax.random, 'fold_in', lambda key, data: key):
        if base.inside_transform():
            return stateful.eval_shape(fun, *args, **kwargs)
        else:
            return jax.eval_shape(fun, *args, **kwargs)
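
A minimal usage sketch (hedged): assuming this helper is exposed as ``hk.experimental.fast_eval_shape``, it is called exactly like ``jax.eval_shape`` on a transformed ``init``/``apply`` function; the MLP and shapes below are purely illustrative.

import haiku as hk
import jax
import jax.numpy as jnp

def forward(x):
    return hk.nets.MLP([128, 10])(x)

f = hk.transform(forward)
rng = jax.random.PRNGKey(0)
x = jnp.ones([8, 784])

# Shapes/dtypes of the params init would create, without materialising them.
param_shapes = hk.experimental.fast_eval_shape(f.init, rng, x)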
Example #2
def cond(*args, **kwargs):
    """Equivalent to :func:`jax.lax.cond` but with Haiku state passed in/out."""
    if not base.inside_transform():
        raise ValueError(
            "hk.cond() should not be used outside of hk.transform(). "
            "Use jax.cond() instead.")

    try:
        bound_args = inspect.signature(_old_cond).bind(*args, **kwargs)
    except TypeError:
        bound_args = inspect.signature(_new_cond).bind(*args, **kwargs)
        pred, true_fun, false_fun, operand = bound_args.args
    else:
        pred, true_operand, true_fun, false_operand, false_fun = bound_args.args
        true_fun = lambda op, f=true_fun: f(op[0])
        false_fun = lambda op, f=false_fun: f(op[1])
        operand = (true_operand, false_operand)

    reserve_up_to_full_rng_block()
    stateful_branch_mem = _memoize_by_id(stateful_branch)
    state = internal_state()
    out, state = jax.lax.cond(pred,
                              true_fun=stateful_branch_mem(true_fun),
                              false_fun=stateful_branch_mem(false_fun),
                              operand=(state, operand))
    update_internal_state(state)
    return out
Example #3
def fori_loop(lower, upper, body_fun, init_val):
    """Equivalent to ``jax.lax.fori_loop`` with Haiku state threaded in/out."""
    if not base.inside_transform():
        raise ValueError(
            "hk.fori_loop() should not be used outside of hk.transform(). "
            "Use jax.lax.fori_loop() instead.")

    def pure_body_fun(i, val):
        state, val = val
        with temporary_internal_state(state):
            val = body_fun(i, val)
            state = internal_state()
            return state, val

    if not base.params_frozen():
        # During init we need to unwind one step of the loop to ensure the Haiku
        # state before and after the body has the same structure.
        init_val = body_fun(lower, init_val)
        lower += 1
        if upper - lower == 0:
            return init_val

    state = internal_state()
    init_val = state, init_val
    state, val = jax.lax.fori_loop(lower, upper, pure_body_fun, init_val)
    update_internal_state(state)
    return val
Example #4
def fori_loop(lower, upper, body_fun, init_val):
    """Equivalent to :func:`jax.lax.fori_loop` with Haiku state passed in/out."""
    if not base.inside_transform():
        raise ValueError(
            "hk.fori_loop() should not be used outside of hk.transform(). "
            "Use jax.lax.fori_loop() instead.")

    def pure_body_fun(i, val):
        state, val = val
        with temporary_internal_state(state):
            val = body_fun(i, val)
            reserve_up_to_full_rng_block()
            state = internal_state()
            return state, val

    if not base.params_frozen():
        # During init we need to unwind one step of the loop to ensure the Haiku
        # state before and after the body has the same structure.
        init_val = body_fun(lower, init_val)
        lower += 1
        try:
            if upper - lower == 0:
                return init_val
        except jax.errors.ConcretizationTypeError:
            # upper or lower might be tracers, which jax.lax.fori_loop can handle.
            pass

    reserve_up_to_full_rng_block()
    state = internal_state()
    init_val = state, init_val
    state, val = jax.lax.fori_loop(lower, upper, pure_body_fun, init_val)
    update_internal_state(state)
    return val
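
A hedged usage sketch: inside a transformed function, ``hk.fori_loop`` lets the loop body use Haiku modules, with their parameters threaded through the loop carry; the module and sizes below are illustrative.

import haiku as hk
import jax
import jax.numpy as jnp

def forward(x):
    lin = hk.Linear(x.shape[-1])

    def body(i, x):
        # lin's parameters are created in the single unrolled step at init time
        # and threaded through the remaining iterations as part of the carry.
        return jax.nn.relu(lin(x))

    return hk.fori_loop(0, 3, body, x)

f = hk.transform(forward)
params = f.init(jax.random.PRNGKey(0), jnp.ones([4, 16]))
out = f.apply(params, None, jnp.ones([4, 16]))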
Example #5
    def wrapped_dec_fun(fun, *dec_args, **dec_kwargs):
        """Decorates a modified version of ``fun`` that passes Haiku state."""

        if not base.inside_transform():
            raise ValueError(
                "hk.{0}() should not be used outside of hk.transform. "
                "Use jax.{0}() instead.".format(dec_fun.__name__))

        @functools.wraps(fun)
        def stateful_fun(*args, **kwargs):
            state_in = kwargs.pop("hk_state")
            with temporary_internal_state(state_in, share_python_state=True):
                out = fun(*args, **kwargs)
                return out, difference(state_in, internal_state())

        dec_stateful_fun = dec_fun(stateful_fun, *dec_args, **dec_kwargs)

        @functools.wraps(dec_stateful_fun)
        def wrapper(*args, **kwargs):
            kwargs["hk_state"] = internal_state()
            out, state = dec_stateful_fun(*args, **kwargs)
            update_internal_state(state)
            return out

        return wrapper
Example #6
def wrapper(*args, **kwargs):
    if base.inside_transform():
        stateful_named_call = thread_hk_state_in_kwargs(jax.named_call)
        named_fun = stateful_named_call(fun, name=name)
    else:
        named_fun = jax.named_call(fun, name=name)
    out = named_fun(*args, **kwargs)
    return out
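
A heavily hedged sketch of how a wrapper like this is used; the public entry point is assumed here to be ``hk.experimental.named_call`` (the exact path and signature may differ between Haiku versions). It tags a sub-computation with a name that shows up in JAX traces while still threading Haiku state when called inside a transform.

import haiku as hk
import jax
import jax.numpy as jnp

def forward(x):
    mlp = hk.nets.MLP([16, 4])
    # Hypothetical call: wrap the module so it appears as "mlp_block" in traces.
    return hk.experimental.named_call(mlp, name="mlp_block")(x)

f = hk.transform(forward)
params = f.init(jax.random.PRNGKey(0), jnp.ones([2, 8]))
out = f.apply(params, None, jnp.ones([2, 8]))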
Example #7
def scan(f, init, xs, length=None, reverse=False):
    """Equivalent to `jax.lax.scan` but with Haiku state threaded in and out."""
    if not base.inside_transform():
        raise ValueError(
            "hk.scan() should not be used outside of hk.transform(). "
            "Use jax.scan() instead.")

    if length is None:
        length = jax.tree_leaves(xs)[0].shape[0]

    running_init_fn = not base.params_frozen()

    if running_init_fn:
        # During `init` we need to unroll one step of the scan because our carry
        # contains the Haiku state, which may change structure during `init`
        # (e.g. as state is created).
        if not length:
            x0 = jax.tree_map(lambda x: jnp.zeros(x.shape[1:], x.dtype), xs)
            _, y0 = f(init, x0)
            y0 = jax.tree_map(lambda y: jnp.zeros((0, ) + y.shape, y.dtype),
                              y0)
            return init, y0

        if reverse:
            x0 = jax.tree_map(lambda x: x[-1], xs)
            xs = jax.tree_map(lambda x: x[:-1], xs)
        else:
            x0 = jax.tree_map(lambda x: x[0], xs)
            xs = jax.tree_map(lambda x: x[1:], xs)
        init, y0 = f(init, x0)
        y0 = jax.tree_map(lambda y: jnp.expand_dims(y, 0), y0)
        length -= 1
        if not length:
            return init, y0

    def stateful_fun(carry, x):
        carry, state = carry
        with temporary_internal_state(state):
            with base.assert_no_new_parameters():
                carry, out = f(carry, x)
            carry = (carry, internal_state(params=False))
            return carry, out

    # We know that we don't need to thread params in and out, since for init we
    # have already created them (given that above we unroll one step of the scan)
    # and for apply we know they are immutable. As such we only need to thread the
    # state and rng in and out.
    init = (init, internal_state(params=False))
    (carry, state), ys = jax.lax.scan(stateful_fun, init, xs, length, reverse)
    update_internal_state(state)

    if running_init_fn:
        ys = jax.tree_multimap(lambda y0, ys: jnp.concatenate([y0, ys]), y0,
                               ys)

    return carry, ys
Example #8
def wrapper(*args, **kwargs):
  if base.inside_transform():
    # fun might be stateful, in which case we need to explicitly thread
    # state in and out of fun to preserve fun as functionally pure.
    state = stateful.internal_state()
    named_f = _named_call(statefulify(fun, state), name=name)
    out, state = named_f(*args, **kwargs)
    stateful.update_internal_state(state)
  else:
    out = _named_call(fun, name=name)(*args, **kwargs)
  return out
Example #9
def eval_shape(fun, *args, **kwargs):
    """Equivalent to jax.eval_shape with any changed Haiku state discarded."""
    if not base.inside_transform():
        raise ValueError(
            "hk.eval_shape() should not be used outside of hk.transform(). "
            "Use jax.eval_shape() instead.")

    with temporary_internal_state(internal_state()):
        out_shape = jax.eval_shape(fun, *args, **kwargs)
    # Don't update changed state
    return out_shape
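
A hedged usage sketch: within a transformed function, ``hk.eval_shape`` reports what a sub-computation would produce without keeping any parameters, state or RNG it would have consumed while tracing; the network and shapes are illustrative.

import haiku as hk
import jax
import jax.numpy as jnp

def forward(x):
    net = hk.nets.MLP([32, 10])
    # Abstract evaluation; Haiku state created during this trace is discarded.
    out_spec = hk.eval_shape(net, x)      # a jax.ShapeDtypeStruct
    assert out_spec.shape == (x.shape[0], 10)
    return net(x)

f = hk.transform(forward)
params = f.init(jax.random.PRNGKey(0), jnp.ones([4, 8]))
y = f.apply(params, None, jnp.ones([4, 8]))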
Example #10
def switch(index, branches, operand):
    """Equivalent to :func:`jax.lax.switch` but with Haiku state passed in/out."""
    if not base.inside_transform():
        raise ValueError(
            "hk.switch() should not be used outside of hk.transform(). "
            "Use jax.switch() instead.")

    reserve_up_to_full_rng_block()
    state = internal_state()
    out, state = jax.lax.switch(index, tuple(map(stateful_branch, branches)),
                                (state, operand))
    update_internal_state(state)
    return out
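
A hedged usage sketch: ``hk.switch`` selects one branch by index while threading Haiku state through it; every branch below uses the same shared module so each branch yields Haiku state with an identical structure (module and branches are illustrative).

import haiku as hk
import jax
import jax.numpy as jnp

def forward(index, x):
    proj = hk.Linear(8)
    branches = [
        lambda x: proj(x),
        lambda x: jax.nn.relu(proj(x)),
        lambda x: jnp.tanh(proj(x)),
    ]
    return hk.switch(index, branches, x)

f = hk.transform(forward)
params = f.init(jax.random.PRNGKey(0), jnp.asarray(0), jnp.ones([2, 4]))
out = f.apply(params, None, jnp.asarray(2), jnp.ones([2, 4]))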
Example #11
def cond(pred, true_operand, true_fun, false_operand, false_fun):
  """Equivalent to ``jax.lax.cond`` but with Haiku state threaded in and out."""
  if not base.inside_transform():
    raise ValueError("hk.cond() should not be used outside of hk.transform(). "
                     "Use jax.cond() instead.")
  state = internal_state()
  out, state = jax.lax.cond(pred,
                            true_operand=(state, true_operand),
                            true_fun=stateful_branch(true_fun),
                            false_operand=(state, false_operand),
                            false_fun=stateful_branch(false_fun))
  update_internal_state(state)
  return out
Example #12
    def wrapper(self, *a, **k):
        if self._used:  # pylint: disable=protected-access
            raise ValueError("State updater must only be used once.")

        if not base.inside_transform():
            raise ValueError(
                "State updater must be used inside hk.transform_with_state.")

        if self._context_id != id(base.current_context()):  # pylint: disable=protected-access
            raise ValueError(
                "State updater must be used within the same call to init/apply."
            )

        self._used = True  # pylint: disable=protected-access

        return f(self, *a, **k)
Example #13
def eval_shape(fun, *args, **kwargs):
    """Equivalent to jax.eval_shape with any changed Haiku state discarded."""
    if not base.inside_transform():
        raise ValueError(
            "hk.eval_shape() should not be used outside of hk.transform(). "
            "Use jax.eval_shape() instead.")

    @functools.wraps(fun)
    def stateless_fun(state, *args, **kwargs):
        with temporary_internal_state(state):
            out = fun(*args, **kwargs)
            # Don't return changed state
            return out

    out_shape = jax.eval_shape(stateless_fun, internal_state(), *args,
                               **kwargs)
    return out_shape
Example #14
    def wrapper(*args, **kwargs):
        side_channel = {"non_jaxtypes": [], "treedef": None}
        wrapped_fun = hide_non_jaxtype_outputs(fun, side_channel)
        if base.inside_transform():
            wrapped_fun = thread_hk_state_in_kwargs(jax.named_call)(
                wrapped_fun, name=name)
        else:
            wrapped_fun = jax.named_call(wrapped_fun, name=name)

        jax_types = wrapped_fun(*args, **kwargs)

        non_jaxtypes = side_channel["non_jaxtypes"]
        out_leaves = [
            y if x is None else x for x, y in zip(jax_types, non_jaxtypes)
        ]
        out = jax.tree_unflatten(side_channel["treedef"], out_leaves)

        return out
Example #15
def cond(*args, **kwargs):
    """Equivalent to :func:`jax.lax.cond` but with Haiku state passed in/out."""
    if not base.inside_transform():
        raise ValueError(
            "hk.cond() should not be used outside of hk.transform(). "
            "Use jax.cond() instead.")

    try:
        bound_args = inspect.signature(_old_cond).bind(*args, **kwargs)
        pred, true_operand, true_fun, false_operand, false_fun = bound_args.args
        if not callable(true_fun) or not callable(false_fun):
            # Two operand new cond case: cond(pred, tf, ff, 1, 2)
            raise TypeError
    except TypeError:
        bound_args = inspect.signature(_new_cond).bind(*args, **kwargs)
        bound_args.apply_defaults()
        pred, true_fun, false_fun, *operands = bound_args.args
        operand = bound_args.kwargs["operand"]
        if operand is not SENTINEL:
            if operands:
                raise ValueError(
                    "When the operand keyword argument is used you cannot "  # pylint: disable=raise-missing-from
                    "also pass operands positionally. Got "
                    f"operand={operand} and *operands={tuple(operands)}")
            operands = (operand, )
            del operand
    else:
        true_fun = lambda op, f=true_fun: f(op[0])
        false_fun = lambda op, f=false_fun: f(op[1])
        operands = ((true_operand, false_operand), )

    reserve_up_to_full_rng_block()
    stateful_branch_mem = _memoize_by_id(stateful_branch)
    state = internal_state()
    out, state = jax.lax.cond(pred,
                              true_fun=stateful_branch_mem(true_fun),
                              false_fun=stateful_branch_mem(false_fun),
                              operand=(state, operands))
    update_internal_state(state)
    return out
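
A hedged usage sketch for the newer positional-operand signature: both branches call the same module, so the Haiku state threaded out of each branch has a matching structure (module and branch bodies are illustrative).

import haiku as hk
import jax
import jax.numpy as jnp

def forward(pred, x):
    mod = hk.Linear(4)
    return hk.cond(pred,
                   lambda x: mod(x),    # true branch
                   lambda x: -mod(x),   # false branch
                   x)

f = hk.transform(forward)
params = f.init(jax.random.PRNGKey(0), jnp.asarray(True), jnp.ones([2, 4]))
out = f.apply(params, None, jnp.asarray(False), jnp.ones([2, 4]))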
Example #16
def scan(f, init, xs, length=None, reverse=False, unroll=1):
    """Equivalent to :func:`jax.lax.scan` but with Haiku state passed in/out."""
    if not base.inside_transform():
        raise ValueError(
            "hk.scan() should not be used outside of hk.transform(). "
            "Use jax.scan() instead.")

    if length is None:
        length = jax.tree_leaves(xs)[0].shape[0]

    running_init_fn = not base.params_frozen()

    if running_init_fn:
        # During `init` we need to unroll one step of the scan because our carry
        # contains the Haiku state, which may change structure during `init`
        # (e.g. as state is created).
        if not length:
            x0 = jax.tree_map(lambda x: jnp.zeros(x.shape[1:], x.dtype), xs)
            _, y0 = f(init, x0)
            y0 = jax.tree_map(lambda y: jnp.zeros((0, ) + y.shape, y.dtype),
                              y0)
            return init, y0

        if reverse:
            x0 = jax.tree_map(lambda x: x[-1], xs)
            xs = jax.tree_map(lambda x: x[:-1], xs)
        else:
            x0 = jax.tree_map(lambda x: x[0], xs)
            xs = jax.tree_map(lambda x: x[1:], xs)
        init, y0 = f(init, x0)
        y0 = jax.tree_map(lambda y: jnp.expand_dims(y, 0), y0)
        length -= 1
        if not length:
            return init, y0

    @functools.wraps(f)
    def stateful_fun(carry, x):
        carry, state = carry
        with temporary_internal_state(state):
            with base.assert_no_new_parameters(), \
                 base.push_jax_trace_level():
                carry, out = f(carry, x)
            reserve_up_to_full_rng_block()
            carry = (carry, internal_state(params=False))
            return carry, out

    # Before pulling out the internal state, reserve a full block of RNG keys.
    # This makes sure we always pass the same number of subkeys in and out of
    # the scan carry (scan requires equal-length lists).
    # After every scan iteration we reserve back up to the full block.
    reserve_up_to_full_rng_block()

    # We know that we don't need to thread params in and out, since for init we
    # have already created them (given that above we unroll one step of the scan)
    # and for apply we know they are immutable. As such we only need to thread the
    # state and rng in and out.

    init = (init, internal_state(params=False))
    (carry, state), ys = jax.lax.scan(stateful_fun,
                                      init,
                                      xs,
                                      length,
                                      reverse,
                                      unroll=unroll)
    update_internal_state(state)

    if running_init_fn:
        if reverse:
            ys = jax.tree_map(lambda y0, ys: jnp.concatenate([ys, y0]), y0, ys)
        else:
            ys = jax.tree_map(lambda y0, ys: jnp.concatenate([y0, ys]), y0, ys)

    return carry, ys
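
A hedged usage sketch: ``hk.scan`` runs a step function over a leading axis while threading Haiku parameters, state and RNG through the carry; applying a recurrent core once per step, as below, is essentially what utilities like ``hk.dynamic_unroll`` do (core and shapes are illustrative).

import haiku as hk
import jax
import jax.numpy as jnp

def forward(xs):                      # xs: [time, batch, features]
    core = hk.GRU(8)
    carry = core.initial_state(batch_size=xs.shape[1])

    def step(carry, x):
        out, carry = core(x, carry)   # the core's parameters live in Haiku state
        return carry, out

    _, ys = hk.scan(step, carry, xs)
    return ys                         # ys: [time, batch, 8]

f = hk.transform(forward)
xs = jnp.ones([5, 2, 3])
params = f.init(jax.random.PRNGKey(0), xs)
ys = f.apply(params, None, xs)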
Example #17
def value_and_grad(fun, argnums=0, has_aux=False, holomorphic=False):
    r"""Creates a function which evaluates both ``fun`` and the grad of ``fun``.

  NOTE: You only need this in the very specific case where you want to take a
  gradient **inside** a :func:`transform`\ ed function and the function you are
  differentiating uses :func:`set_state`. For example:

  >>> class MyModule(hk.Module):
  ...   def __call__(self, x):
  ...     hk.set_state("last", jnp.sum(x))
  ...     return x ** 2

  >>> def f(x):
  ...   m = MyModule()
  ...   y, g = hk.value_and_grad(m)(x)
  ...   return y, g

  >>> f = hk.transform_with_state(f)
  >>> x = jnp.array(2.)
  >>> _ = jax.jit(f.init)(None, x)

  Args:
    fun: Function to be differentiated. Its arguments at positions specified by
      ``argnums`` should be arrays, scalars, or standard Python containers. It
      should return a scalar (which includes arrays with shape ``()`` but not
      arrays with shape ``(1,)`` etc.)
    argnums: Optional, integer or tuple of integers. Specifies which positional
      argument(s) to differentiate with respect to (default 0).
    has_aux: Optional, bool. Indicates whether ``fun`` returns a pair where the
      first element is considered the output of the mathematical function to be
      differentiated and the second element is auxiliary data. Default False.
    holomorphic: Optional, bool. Indicates whether ``fun`` is promised to be
      holomorphic. Default False.

  Returns:
    A function with the same arguments as ``fun`` that evaluates both ``fun``
    and the gradient of ``fun`` and returns them as a pair (a two-element
    tuple). If ``argnums`` is an integer then the gradient has the same shape
    and type as the positional argument indicated by that integer. If ``argnums``
    a tuple of integers, the gradient is a tuple of values with the same shapes
    and types as the corresponding arguments.
  """
    if not base.inside_transform():
        raise ValueError(
            "hk.grad() should not be used outside of hk.transform(). "
            "Use jax.grad() instead.")

    @functools.wraps(fun)
    def stateful_fun(*args, **kwargs):
        state_in = kwargs.pop("hk_state")
        with temporary_internal_state(state_in):
            out = fun(*args, **kwargs)
            out, aux = (out if has_aux else (out, None))
            state_out = difference(state_in, internal_state())
            return out, (aux, state_out)

    grad_fun = jax.value_and_grad(stateful_fun,
                                  argnums=argnums,
                                  has_aux=True,
                                  holomorphic=holomorphic)

    @functools.wraps(grad_fun)
    def wrapper(*args, **kwargs):
        kwargs["hk_state"] = internal_state()
        (value, (aux, hk_state)), grads = grad_fun(*args, **kwargs)
        update_internal_state(hk_state)
        if has_aux:
            return (value, aux), grads
        else:
            return value, grads

    return wrapper
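
Continuing the docstring example above (a hedged sketch): ``apply`` returns the value/gradient pair computed inside the transformed function together with the updated Haiku state.

params, state = f.init(None, x)
(y, g), state = f.apply(params, state, None, x)
# y == 4.0 and g == 4.0 (the derivative of x**2 at x == 2.0); `state` now holds
# the "last" entry written via hk.set_state inside MyModule.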