Example #1
def saved_residuals(f, *args,
                    **kwargs) -> List[Tuple[core.AbstractValue, str]]:
    """Describe the residuals saved when linearizing ``f`` at the given inputs.

    ``f`` is linearized with ``jax.linearize`` at ``(*args, **kwargs)``. The
    linear function it returns is traced with ``jax.make_jaxpr``. Every output
    variable of that jaxpr is reported together with a description of where
    it comes from.

    Args:
      f: the function to inspect.
      *args: positional example arguments at which ``f`` is linearized.
      **kwargs: keyword example arguments at which ``f`` is linearized.

    Returns:
      A list of ``(aval, description)`` pairs with one entry per jaxpr output.
      The description says the value comes from a literal, from a constant,
      from an input argument, from a ``name_p``-named value, or from the
      summarized source location of the equation that produced it.
    """
    args, in_tree = tree_flatten((args, kwargs))

    def f_(*args):
        # Rebuild the original (args, kwargs) structure from flat leaves.
        args, kwargs = tree_unflatten(in_tree, args)
        return f(*args, **kwargs)

    jaxpr = jax.make_jaxpr(
        lambda *args: jax.linearize(f_, *args)[1])(*args).jaxpr
    res_lits = [x for x in jaxpr.outvars if isinstance(x, core.Literal)]
    res_vars = {x for x in jaxpr.outvars if not isinstance(x, core.Literal)}

    results = []

    for x in res_lits:
        results.append((x.aval, 'from a literal'))

    for v in jaxpr.constvars:
        if v in res_vars:
            results.append((v.aval, 'from a constant'))

    assert len(jaxpr.invars) == len(args)
    for i, v in enumerate(jaxpr.invars):
        if v in res_vars:
            src = f'from {pe.arg_info_pytree(f, in_tree, True, [i])}'
            results.append((v.aval, src))

    for eqn in jaxpr.eqns:
        src = source_info_util.summarize(eqn.source_info)
        for v in eqn.outvars:
            if v in res_vars:
                if eqn.primitive is name_p:
                    results.append(
                        (v.aval, f"named '{eqn.params['name']}' from {src}"))
                else:
                    results.append((v.aval, f'from {src}'))

    assert len(results) == len(jaxpr.outvars)
    return results
Example #2
def map_transpose(primitive, params, call_jaxpr, args, ct, _, reduce_axes):
  """Transpose rule for map primitives.

  The transpose is computed by binding ``primitive`` on a ``backward_pass``
  over ``call_jaxpr``. The primal (defined) entries of ``args`` and the
  nonzero output cotangents ``ct`` are passed as its mapped inputs.

  Args:
    primitive: the map primitive being transposed.
    params: the primitive's parameters. ``in_axes``, ``out_axes``, ``name``,
      ``axis_size`` and ``axis_name`` are read.
    call_jaxpr: the jaxpr of the mapped body.
    args: the primal inputs; linear inputs are undefined primals.
    ct: the output cotangents, possibly symbolic ``Zero`` values.
    _: unused.
    reduce_axes: forwarded to ``backward_pass``.

  Returns:
    A tuple with one cotangent per entry of ``args``.
  """
  all_args, in_tree_def = tree_flatten(((), args, ct))  # empty consts
  fun = lu.hashable_partial(lu.wrap_init(backward_pass), call_jaxpr, reduce_axes, False)
  fun, nz_arg_cts = nonzero_outputs(fun)
  fun, out_tree = flatten_fun_nokwargs(fun, in_tree_def)
  # Preserve axis for primal arguments, skip tangents (represented as undefined primals).
  in_axes, out_axes = params['in_axes'], params['out_axes']
  new_in_axes = (*[axis for axis, x in zip(in_axes, args)
                   if not is_undefined_primal(x)],
                 *[axis for axis, x in zip(out_axes, ct)
                   if type(x) is not Zero])
  # The interim strategy we use below (until avals-with-names) only works
  # when all outputs are mapped.
  assert all(out_axis is not None for out_axis in out_axes), out_axes
  # NOTE: This assumes that the output cotangents being zero is a deterministic
  #       function of which input cotangents were zero.
  @as_hashable_function(closure=(in_axes, tuple(type(c) is Zero for c in ct)))
  def out_axes_thunk():
    # Unmapped inputs (axis None) get their cotangents mapped along axis 0;
    # they are summed back down below.
    return tuple(axis or 0 for axis, nz in zip(in_axes, nz_arg_cts()) if nz)
  new_params = dict(params, name=wrap_name(params['name'], 'transpose'),
                    in_axes=new_in_axes, out_axes_thunk=out_axes_thunk)
  del new_params['out_axes']
  update_params = call_transpose_param_updaters.get(primitive)
  if update_params:
    new_params = update_params(new_params, map(is_undefined_primal, args),
                               [type(x) is not Zero for x in ct])
  out_flat = primitive.bind(fun, *all_args, **new_params)
  arg_cts = tree_unflatten(out_tree(), out_flat)

  # The freevars are being fanned out (not mapped). During transpose the
  # dual of fan-out is fan-in-sum. We apply it to the unmapped invars.
  assert len(in_axes) == len(arg_cts)
  def unmap_zero(zero, in_axis):
    return (zero if in_axis is None else
            Zero(core.unmapped_aval(params['axis_size'], params['axis_name'], in_axis, zero.aval)))
  arg_cts = (unmap_zero(arg_ct, in_axis) if type(arg_ct) is Zero else
             arg_ct if in_axis is not None else
             arg_ct.sum(0)  # fan-in-sum over the axis-0 mapping of unmapped inputs
             for arg_ct, in_axis in zip(arg_cts, in_axes))
  return tuple(arg_cts)
Example #3
def map_ildj(prim, incells, outcells, **params):
  """InverseAndILDJ rule for the map primitives.

  ``incells[0]`` is the mapped function. The remaining in-cells and the
  out-cells have their leading slice/dimension removed. They are then
  propagated through ``prim`` via ``flat_propagate``, and the leading
  dimension is added back to the resulting cells.

  Returns:
    ``(new_incells, new_outcells, subenv)``.
  """
  f, incells = incells[0], incells[1:]

  def slice_aval(aval):
    # Abstract value with the leading (mapped) dimension removed.
    return abstract_arrays.ShapedArray(aval.shape[1:], aval.dtype,
                                       aval.weak_type)

  def add_slice(cell, old_cell):
    # Prepend a full slice over the leading dimension of `old_cell`.
    new_slices = [
        NDSlice(ndslice.value, ndslice.ildj, Slice(0, old_cell.aval.shape[0]),
                *ndslice.slices) for ndslice in cell.slices
    ]
    return InverseAndILDJ(old_cell.aval, new_slices)

  def remove_slice(cell):
    # Drop the leading slice and the leading dimension of the aval.
    new_slices = [
        NDSlice(ndslice.value, ndslice.ildj, *ndslice.slices[1:])
        for ndslice in cell.slices
    ]
    aval = slice_aval(cell.aval)
    return InverseAndILDJ(aval, new_slices)

  mapped_incells = safe_map(remove_slice, incells)
  mapped_outcells = safe_map(remove_slice, outcells)
  flat_vals, in_tree = tree_util.tree_flatten((mapped_incells, mapped_outcells))
  f, aux = flat_propagate(f, in_tree)
  # Assume all invars as mapped
  new_mapped_invars = (True,) * len(flat_vals)
  new_params = dict(params, mapped_invars=new_mapped_invars)
  subenv_vals = prim.bind(f, *flat_vals, **new_params)
  subenv_tree = aux()
  subenv = tree_util.tree_unflatten(subenv_tree, subenv_vals)
  new_incells = [subenv.read(var) for var in subenv.jaxpr.invars]
  new_outcells = [subenv.read(var) for var in subenv.jaxpr.outvars]
  new_incells = [add_slice(v, old_v)
                 for old_v, v in safe_zip(incells, new_incells)]
  new_outcells = [add_slice(v, old_v)
                  for old_v, v in safe_zip(outcells, new_outcells)]
  return new_incells, new_outcells, subenv
Example #4
def haiku_module(name, nn_module, *, input_shape=None, **kwargs):
    """
    Declare a :mod:`~haiku` style neural network inside a
    model so that its parameters are registered for optimization via
    :func:`~numpyro.primitives.param` statements.

    :param str name: name of the module to be registered.
    :param haiku.Module nn_module: a `haiku` Module which has .init and .apply methods
    :param tuple input_shape: shape of the input taken by the
        neural network.
    :param kwargs: optional keyword arguments to initialize haiku neural network
        as an alternative to `input_shape`
    :return: a callable with bound parameters that takes an array
        as an input and returns the neural network transformed output
    :raises ImportError: if `haiku` is not installed.
    """
    try:
        import haiku  # noqa: F401
    except ImportError as e:
        raise ImportError(
            "Looking like you want to use haiku to declare "
            "nn modules. This is an experimental feature. "
            "You need to install `haiku` to be able to use this feature. "
            "It can be installed with `pip install dm-haiku`.") from e

    module_key = name + '$params'
    nn_params = numpyro.param(module_key)
    if nn_params is None:
        args = (jnp.ones(input_shape), ) if input_shape is not None else ()
        # feed in dummy data to init params
        rng_key = numpyro.prng_key()
        nn_params = nn_module.init(rng_key, *args, **kwargs)
        # haiku init returns an immutable dict
        nn_params = haiku.data_structures.to_mutable_dict(nn_params)
        # we cast it to a mutable one to be able to set priors for parameters
        # make sure that nn_params keep the same order after unflatten
        params_flat, tree_def = tree_flatten(nn_params)
        nn_params = tree_unflatten(tree_def, params_flat)
        numpyro.param(module_key, nn_params)
    # haiku's apply takes (params, rng, *inputs); no rng is passed here.
    return partial(nn_module.apply, nn_params, None)
Example #5
File: ad.py Project: John1Tang/jax
    def process_call(self, call_primitive, f, tracers, params):
        """JVP rule for call-like primitives (including map primitives).

        ``f`` is rebound under ``call_primitive`` as its JVP: the primals are
        passed followed by the nonzero tangents. For map primitives,
        ``in_axes`` and ``out_axes_thunk`` are extended to cover the tangent
        inputs and outputs.

        Returns:
          A list of ``JVPTracer`` objects pairing each primal output with its
          tangent.
        """
        assert call_primitive.multiple_results
        primals, tangents = unzip2((t.primal, t.tangent) for t in tracers)
        nonzero_tangents, tangent_tree_def = tree_flatten(tangents)
        nz_tangents = [type(t) is not Zero for t in tangents]
        if 'name' in params and not config.jax_experimental_name_stack:
            params = dict(params, name=wrap_name(params['name'], 'jvp'))
        f_jvp = jvp_subtrace(f, self.main)
        f_jvp, nz_tangents_out = nonzero_tangent_outputs(f_jvp)
        if isinstance(call_primitive, core.MapPrimitive):
            in_axes = params['in_axes']
            # Tangent inputs reuse the axes of their primals.
            tangent_in_axes = [
                ax for ax, nz in zip(in_axes, nz_tangents) if nz
            ]
            out_axes_thunk = params['out_axes_thunk']
            # The new thunk depends deterministically on the old thunk and the wrapped function.
            # Any caching already has to include the wrapped function as part of the key, so we
            # only use the previous thunk for equality checks.
            # NOTE: This assumes that the output tangents being zero is a deterministic
            #       function of which input tangents were zero.
            @as_hashable_function(closure=(tuple(nz_tangents), out_axes_thunk))
            def new_out_axes_thunk():
                out_axes = out_axes_thunk()
                return (*out_axes,
                        *(ax for ax, nz in zip(out_axes, nz_tangents_out())
                          if nz))

            params = dict(params,
                          in_axes=(*in_axes, *tangent_in_axes),
                          out_axes_thunk=new_out_axes_thunk)
        f_jvp, out_tree_def = traceable(f_jvp, len(primals), tangent_tree_def)
        update_params = call_param_updaters.get(call_primitive)
        new_params = update_params(params,
                                   nz_tangents) if update_params else params
        f_jvp = _update_annotation(f_jvp, f.in_type, nz_tangents)
        result = call_primitive.bind(f_jvp, *primals, *nonzero_tangents,
                                     **new_params)
        primal_out, tangent_out = tree_unflatten(out_tree_def(), result)
        return [JVPTracer(self, p, t) for p, t in zip(primal_out, tangent_out)]
Example #6
def _scan(
    f: Callable[[_Carry, _Input], Tuple[_Carry, _Output]],
    init: _Carry,
    xs: Iterable[_Input],
) -> Tuple[_Carry, _Output]:
    """Implements an unrolled version of scan.

    Based on `jax.lax.scan` and has a similar API. ``f`` is called on the carry
    and on successive slices of ``xs`` along the leading axis of every leaf.
    The per-step outputs are stacked along a new leading axis.

    TODO(schsam): We introduce this function because lax.scan currently has a
    higher peak memory usage than the unrolled version. We will aim to swap this
    out for lax.scan when issue #1273 and related have been resolved.

    Args:
      f: step function mapping ``(carry, x)`` to ``(new_carry, y)``.
      init: initial carry.
      xs: pytree whose leaves share a leading axis to iterate over.

    Returns:
      ``(final_carry, stacked_ys)``.

    Raises:
      ValueError: if ``xs`` provides no steps (no leaves or a zero-length
        leading axis). The output structure cannot be inferred without
        calling ``f``.
    """
    # `tree_multimap` was removed from JAX; `tree_map` accepts several trees.
    from jax.tree_util import tree_map

    carry = init
    ys = []
    flat_xs, tree_def = tree_flatten(xs)
    for flat_x in zip(*flat_xs):
        x = tree_unflatten(tree_def, flat_x)
        carry, y = f(carry, x)
        ys.append(y)

    if not ys:
        raise ValueError('_scan requires xs with a non-empty leading axis.')
    return carry, tree_map(lambda *y: np.stack(y), *ys)
Example #7
  def start_tracing_body(self):
    """Begin tracing the loop body.

    Opens a sub-trace on the scope. Every entry of the scope's mutable state
    is then replaced by fresh tracer arguments, since the whole state is
    carried. For each key, the tracer variables, their avals and a copy of
    the original value are recorded. Finally, one more tracer argument is
    created for the loop index.
    """
    # TODO: This is the first part of partial_eval.trace_to_subjaxpr. Share.
    self.trace = self.scope.start_subtrace()
    trace = self.trace
    scope = self.scope
    # The entire state is carried; keys are visited in sorted order.
    self.carried_state_names = sorted(scope._mutable_state.keys())
    for name in self.carried_state_names:
      leaves, treedef = tree_util.tree_flatten(scope._mutable_state[name])
      leaf_avals = [_BodyTracer.abstractify(leaf) for leaf in leaves]
      leaf_pvals = [pe.PartialVal.unknown(aval) for aval in leaf_avals]
      leaf_vars = [trace.new_arg(pval) for pval in leaf_pvals]
      self.carried_state_vars[name] = leaf_vars
      # Point the scope's mutable state at the new tracing variables.
      scope._mutable_state[name] = treedef.unflatten(leaf_vars)
      scope._mutable_state_aval[name] = treedef.unflatten(leaf_avals)
      # Keep a copy of the initial state, rebuilt from the original leaves.
      self.carried_state_initial[name] = treedef.unflatten(leaves)

    # The loop index is one more unknown traced argument.
    self._index_var = trace.new_arg(
        pe.PartialVal.unknown(_BodyTracer.abstractify(0)))
Example #8
 def transposed(*args):
     """Transpose ``jaxpr`` for one flat list of ``(in_primals, out_cts)`` leaves.

     Closes over ``treedef``, ``jaxpr``, ``reduce_axes`` and ``cell`` from the
     enclosing scope. The jaxpr is partially evaluated, with undefined primals
     treated as unknown inputs. ``ad.backward_pass`` is run on the result, and
     known primals get a ``Zero`` cotangent. The output tree structure is
     stored on ``cell.treedef`` and the flat cotangents are returned.
     """
     in_primals, out_cts = tree_unflatten(treedef, args)
     in_pvals = [
         pe.PartialVal.unknown(x.aval)
         if ad.is_undefined_primal(x) else pe.PartialVal.known(x)
         for x in in_primals
     ]
     primal_fun = lu.wrap_init(partial(core.eval_jaxpr, jaxpr, ()))
     t_jaxpr, _, consts = pe.trace_to_jaxpr_nounits(primal_fun, in_pvals,
                                                    False)
     dummy_args = [ad.UndefinedPrimal(v.aval) for v in t_jaxpr.invars]
     in_cts = ad.backward_pass(t_jaxpr, reduce_axes, False, consts,
                               dummy_args, out_cts)
     # backward_pass yields one cotangent per unknown input, in order.
     in_cts_ = iter(in_cts)
     in_cts = [
         next(in_cts_)
         if ad.is_undefined_primal(x) else ad_util.Zero(x.aval)
         for x in in_primals
     ]
     assert next(in_cts_, None) is None
     in_cts, cell.treedef = tree_flatten(in_cts)
     return in_cts
Example #9
 def __call__(self, *args: Any, **kwargs: Any) -> ReturnValue:
   """Call the function by binding ``custom_jvp_call_p`` with its JVP rule.

   Arguments listed in ``nondiff_argnums`` get ``_stop_gradient`` applied and
   are passed to the JVP rule as static leading arguments. Only the remaining
   arguments are flattened and bound.

   Raises:
     AttributeError: if no JVP rule was defined with ``defjvp``.
   """
   if not self.jvp:
     msg = "No JVP defined for custom_jvp function {} using defjvp."
     raise AttributeError(msg.format(self.__name__))
   args = _resolve_kwargs(self.fun, args, kwargs)
   if self.nondiff_argnums:
     nondiff_argnums = set(self.nondiff_argnums)
     args = tuple(_stop_gradient(x) if i in nondiff_argnums else x
                  for i, x in enumerate(args))
     diff_argnums = [i for i in range(len(args)) if i not in nondiff_argnums]
     f_, dyn_args = argnums_partial(lu.wrap_init(self.fun), diff_argnums, args)
     static_args = [args[i] for i in self.nondiff_argnums]
     jvp = _add_args(lu.wrap_init(self.jvp), static_args)
   else:
     f_, dyn_args = lu.wrap_init(self.fun), args
     jvp = lu.wrap_init(self.jvp)
   args_flat, in_tree = tree_flatten(dyn_args)
   flat_fun, out_tree1 = flatten_fun_nokwargs(f_, in_tree)
   flat_jvp, out_tree2 = _flatten_jvp(jvp, in_tree)
   out_flat = custom_jvp_call_p.bind(flat_fun, flat_jvp, *args_flat)
   # Whichever of fun/jvp actually ran determines the output tree.
   _, out_tree = lu.merge_linear_aux(out_tree1, out_tree2)
   return tree_unflatten(out_tree, out_flat)
Example #10
File: ad.py Project: jbampton/jax
def linearize(traceable, *primals, **kwargs):
  """Linearize ``traceable`` at ``primals`` using JVP plus partial evaluation.

  The JVP of ``traceable`` is traced with the primals known and the tangents
  unknown. The resulting jaxpr is trimmed so that its inputs are only the
  tangent inputs and its outputs are only the tangent outputs.

  Args:
    traceable: the function to linearize.
    *primals: the point at which to linearize.
    **kwargs: only ``has_aux`` (default ``False``) is read.

  Returns:
    ``(out_primals_consts, out_tangents_pvals, jaxpr, consts)``, with
    ``aux()`` appended as a fifth element when ``has_aux`` is true.
  """
  has_aux = kwargs.pop('has_aux', False)
  if not has_aux:
    jvpfun = jvp(traceable)
  else:
    jvpfun, aux = jvp(traceable, has_aux=True)

  in_pvals = (tuple(pe.PartialVal.known(p) for p in primals)
              + tuple(pe.PartialVal.unknown(get_aval(p).at_least_vspace())
                    for p in primals))
  _, in_tree = tree_flatten(((primals, primals), {}))
  jvpfun_flat, out_tree = flatten_fun(jvpfun, in_tree)
  jaxpr, out_pvals, consts = pe.trace_to_jaxpr(jvpfun_flat, in_pvals)
  out_primals_pvals, out_tangents_pvals = tree_unflatten(out_tree(), out_pvals)
  assert all(out_primal_pval.is_known() for out_primal_pval in out_primals_pvals)
  _, out_primals_consts = unzip2(out_primals_pvals)
  # Keep only the tangent half of the inputs and outputs.
  jaxpr.invars = jaxpr.invars[len(primals):]
  jaxpr.outvars = jaxpr.outvars[len(out_primals_pvals):]
  if not has_aux:
    return out_primals_consts, out_tangents_pvals, jaxpr, consts
  else:
    return out_primals_consts, out_tangents_pvals, jaxpr, consts, aux()
Example #11
def psum(x, axis_name, *, axis_index_groups=None):
  """Compute an all-reduce sum on ``x`` over the pmapped axis ``axis_name``.

  If ``x`` is a pytree then the result is equivalent to mapping this function to
  each leaf in the tree.

  Inputs of boolean dtype are converted to integers before the reduction.

  Args:
    x: array(s) with a mapped axis named ``axis_name``.
    axis_name: hashable Python object used to name a pmapped axis (see the
      :func:`jax.pmap` documentation for more details).
    axis_index_groups: optional list of lists containing axis indices (e.g. for
      an axis of size 4, [[0, 1], [2, 3]] would perform psums over the first
      two and last two replicas). Groups must cover all axis indices exactly
      once, and all groups must be the same size.

  Returns:
    Array(s) with the same shape as ``x`` representing the result of an
    all-reduce sum along the axis ``axis_name``.

  For example, with 4 XLA devices available:

  >>> x = np.arange(4)
  >>> y = jax.pmap(lambda x: jax.lax.psum(x, 'i'), axis_name='i')(x)
  >>> print(y)
  [6 6 6 6]
  >>> y = jax.pmap(lambda x: x / jax.lax.psum(x, 'i'), axis_name='i')(x)
  >>> print(y)
  [ 0.          0.16666667  0.33333334  0.5       ]
  """
  leaves, treedef = tree_util.tree_flatten(x)
  # Booleans are summed as int32 counts.
  leaves = [lax.convert_element_type(leaf, onp.int32)
            if dtypes.dtype(leaf) == onp.bool_ else leaf for leaf in leaves]
  out_flat = psum_p.bind(*leaves, axis_name=axis_name,
                         axis_index_groups=axis_index_groups)
  return tree_util.tree_unflatten(treedef, out_flat)
Example #12
  def update(updates, state, params=None):
    """Apply the inner update only when all ``updates`` are finite.

    Non-finite updates are replaced by zeros and the inner state is kept.
    They are applied anyway once the count of consecutive non-finite steps
    exceeds ``max_consecutive_errors``.

    Returns:
      ``(updates, ApplyIfFiniteState)`` with updated counters and inner state.
    """
    inner_state = state.inner_state
    flat_updates = tree_flatten(updates)[0]
    isfinite = jnp.all(
        jnp.array([jnp.all(jnp.isfinite(p)) for p in flat_updates]))
    # Consecutive non-finite steps; reset to zero on a finite step.
    notfinite_count = jnp.where(isfinite, jnp.zeros([], jnp.int64),
                                1 + state.notfinite_count)

    def do_update(_):
      return inner.update(updates, inner_state, params)
    def reject_update(_):
      return (tree_map(jnp.zeros_like, updates), inner_state)

    updates, new_inner_state = lax.cond(
        jnp.logical_or(isfinite, notfinite_count > max_consecutive_errors),
        do_update, reject_update, operand=None)

    return updates, ApplyIfFiniteState(
        notfinite_count=notfinite_count,
        last_finite=isfinite,
        total_notfinite=jnp.logical_not(isfinite) + state.total_notfinite,
        inner_state=new_inner_state)
Example #13
def soft_vmap(fn, xs, batch_ndims=1, chunk_size=None):
    """
    Vectorizing map that maps a function `fn` over `batch_ndims` leading axes
    of `xs`. This uses jax.vmap over smaller chunks of the batch dimensions
    to keep memory usage constant.

    :param callable fn: The function to map over.
    :param xs: JAX pytree (e.g. an array, a list/tuple/dict of arrays,...)
    :param int batch_ndims: The number of leading dimensions of `xs`
        to apply `fn` element-wise over them.
    :param int chunk_size: Size of each chunk of `xs`.
        Defaults to the size of batch dimensions.
    :returns: output of `fn(xs)`.
    """
    flatten_xs = tree_flatten(xs)[0]
    batch_shape = np.shape(flatten_xs[0])[:batch_ndims]
    for x in flatten_xs[1:]:
        assert np.shape(x)[:batch_ndims] == batch_shape

    # we'll do map(vmap(fn), xs) and make xs.shape = (num_chunks, chunk_size, ...)
    num_chunks = batch_size = int(np.prod(batch_shape))
    prepend_shape = (-1,) if batch_size > 1 else ()
    xs = tree_map(lambda x: jnp.reshape(x, prepend_shape + jnp.shape(x)[batch_ndims:]), xs)
    # XXX: probably for the default behavior with chunk_size=None,
    # it is better to catch OOM error and reduce chunk_size by half until OOM disappears.
    chunk_size = batch_size if chunk_size is None else min(batch_size, chunk_size)
    if chunk_size > 1:
        # Pad only up to the next multiple of chunk_size: no padding (and no
        # extra chunk of zeros) when chunk_size already divides batch_size.
        pad = (-batch_size) % chunk_size
        if pad:
            xs = tree_map(lambda x: jnp.pad(x, ((0, pad),) + ((0, 0),) * (np.ndim(x) - 1)), xs)
        num_chunks = batch_size // chunk_size + int(pad > 0)
        prepend_shape = (-1,) if num_chunks > 1 else ()
        xs = tree_map(lambda x: jnp.reshape(x, prepend_shape + (chunk_size,) + jnp.shape(x)[1:]), xs)
        fn = vmap(fn)

    ys = lax.map(fn, xs) if num_chunks > 1 else fn(xs)
    # Collapse the chunk/inner-chunk axes back into a single batch axis and
    # drop any padded rows.
    map_ndims = int(num_chunks > 1) + int(chunk_size > 1)
    ys = tree_map(lambda y: jnp.reshape(y, (-1,) + jnp.shape(y)[map_ndims:])[:batch_size], ys)
    return tree_map(lambda y: jnp.reshape(y, batch_shape + jnp.shape(y)[1:]), ys)
Example #14
def _call_tf_impl(*args_jax_flat, args_treedef, func_tf, out_avals, **_):
    """Eagerly evaluate ``func_tf`` on JAX arguments and return JAX results.

    Each argument is converted to a TF value. DLPack is used when the array
    is a ``DeviceArray`` on a platform in ``_DLPACK_PLATFORMS`` with a
    supported dtype; otherwise the conversion goes through a NumPy array.
    ``func_tf`` is called inside ``jax2tf_internal.inside_call_tf()``. Each
    flattened result is converted back to JAX using the dtype of the
    matching ``out_avals`` entry.

    Returns:
      A list of JAX arrays, one per flattened result.
    """
    # On GPU we use dlpack to avoid copies of data to the host.
    def _arg_jax_to_tf(arg_jax):
        if (isinstance(arg_jax, xla.DeviceArray)
                and arg_jax.device_buffer.client.platform in _DLPACK_PLATFORMS
                and arg_jax.dtype in dlpack.SUPPORTED_DTYPES):
            arg_dlpack = jax.dlpack.to_dlpack(arg_jax, take_ownership=False)
            return tf.experimental.dlpack.from_dlpack(arg_dlpack)
        # The following avoids copies to the host on CPU, always for DeviceArray
        # and even for ndarray if they are sufficiently aligned.
        # TODO(necula): on TPU this copies to the host!
        return tf.constant(np.asarray(arg_jax))

    args_tf_flat = tuple(map(_arg_jax_to_tf, args_jax_flat))
    with jax2tf_internal.inside_call_tf():
        # Call in TF eager mode
        res_tf = func_tf(*args_treedef.unflatten(args_tf_flat))
    res_tf_flat, _ = tree_util.tree_flatten(res_tf)

    # TODO(necula): check the result for tree and aval

    def _res_tf_to_jax(res_tf: TfVal, out_aval: core.AbstractValue):
        res_tf, _ = jax2tf_internal._tfval_to_tensor_jax_dtype(
            res_tf, jax_dtype=out_aval.dtype)
        if isinstance(res_tf,
                      tf.Tensor) and res_tf.dtype in dlpack.SUPPORTED_DTYPES:
            res_tf_platform = tf.DeviceSpec.from_string(
                res_tf.backing_device).device_type
            res_jax_platform = res_tf_platform.lower()
            if res_jax_platform in _DLPACK_PLATFORMS:
                # Local import: only needed to pick the backend matching the
                # TF tensor's device.
                from jax.lib import xla_bridge
                res_dlpack = tf.experimental.dlpack.to_dlpack(res_tf)
                return jax.dlpack.from_dlpack(
                    res_dlpack,
                    backend=xla_bridge.get_backend(res_jax_platform))

        return jnp.asarray(np.asarray(res_tf))

    return list(map(_res_tf_to_jax, res_tf_flat, out_avals))
Example #15
 def __call__(self, *args: Any, **kwargs: Any) -> ReturnValue:  # pytype: disable=invalid-annotation
     """Call the function by binding ``custom_vjp_call_p`` with fwd/bwd rules.

     Arguments in ``nondiff_argnums`` are checked with ``_check_for_tracers``
     and removed from the dynamic arguments. They are passed to ``bwd`` as
     static leading arguments.

     Raises:
       AttributeError: if fwd/bwd rules were not defined with ``defvjp``.
     """
     if not self.fwd or not self.bwd:
         msg = "No VJP defined for custom_vjp function {} using defvjp."
         raise AttributeError(msg.format(self.__name__))
     args = _resolve_kwargs(self.fun, args, kwargs)
     if self.nondiff_argnums:
         for i in self.nondiff_argnums:
             _check_for_tracers(args[i])
         nondiff_argnums = set(self.nondiff_argnums)
         dyn_argnums = [
             i for i in range(len(args)) if i not in nondiff_argnums
         ]
         f_, dyn_args = argnums_partial(lu.wrap_init(self.fun),
                                        dyn_argnums, args)
         static_args = [args[i] for i in self.nondiff_argnums]
         fwd, _ = argnums_partial(lu.wrap_init(self.fwd),
                                  dyn_argnums, args)
         bwd = _add_args(lu.wrap_init(self.bwd), static_args)
     else:
         f_, dyn_args = lu.wrap_init(self.fun), args
         fwd, bwd = lu.wrap_init(self.fwd), lu.wrap_init(self.bwd)
     args_flat, in_tree = tree_flatten(dyn_args)
     in_avals = [core.raise_to_shaped(core.get_aval(x)) for x in args_flat]
     flat_fun, out_tree = flatten_fun_nokwargs(f_, in_tree)
     flat_fwd, out_trees = _flatten_fwd(fwd, in_tree)
     flat_bwd = _flatten_bwd(bwd, in_tree, in_avals, out_trees)
     out_flat = custom_vjp_call_p.bind(flat_fun, flat_fwd, flat_bwd,
                                       *args_flat, out_trees=out_trees)
     # `fst` tells whether fun (True) or fwd (False) supplied the output tree.
     fst, aux = lu.merge_linear_aux(out_tree, out_trees)
     out_tree = aux if fst else aux[0]
     return tree_unflatten(out_tree, out_flat)
Example #16
 def process_higher_order_primitive(self, trace, call_primitive, f, tracers,
                                    params, is_map):
     """Rebind a call-like primitive so the context's plants are threaded in.

     The dynamic context's plants are flattened together with the tracer
     values and passed as leading inputs, with ``f`` wrapped by
     ``plant_eval``. ``in_axes`` and ``donated_invars`` are extended to cover
     the plant leaves. For ``nest_p``, only the plants of ``params['scope']``
     are passed.

     Returns:
       The primitive's outputs, each lifted with ``trace.pure``.
     """
     del is_map
     name = jax_util.wrap_name(params.pop('name', f.__name__), 'reap')
     context = trace_util.get_dynamic_context(trace)
     vals = [t.val for t in tracers]
     plants = context.plants
     if 'in_axes' in params:
         # TODO(b/199459308): figure out if invars are mapped or unmapped
         params = dict(params,
                       in_axes=(0, ) * len(tree_util.tree_leaves(plants)) +
                       params['in_axes'])
     if 'donated_invars' in params:
         params = dict(params)
         # Plant leaves are never donated.
         params['donated_invars'] = (
             (False, ) * len(tree_util.tree_leaves(plants)) +
             params['donated_invars'])
     elif call_primitive is nest_p:
         plants = plants.get(params['scope'], {})
     all_vals, all_tree = tree_util.tree_flatten((plants, vals))
     f = plant_eval(f, trace, self.settings, all_tree)
     out_vals = call_primitive.bind(f, *all_vals, name=name, **params)
     return jax_util.safe_map(trace.pure, out_vals)
Example #17
def _sparsify_jaxpr(spenv, jaxpr, *argspecs):
  """Trace a sparsified version of the closed ``jaxpr`` for ``argspecs``.

  A wrapper that converts flat inputs back to argspecs, evaluates
  ``jaxpr`` with ``eval_sparse`` and flattens the resulting arrays is traced
  with ``pe.trace_to_jaxpr_dynamic`` on the avals of the flattened inputs.

  Returns:
    ``(closed_jaxpr, out_tree)``. ``closed_jaxpr`` is the traced computation
    with its constvars converted to invars and closed over ``consts``.
    ``out_tree`` is the output pytree structure recorded during tracing.
  """
  # TODO(jakevdp): currently this approach discards all information about
  #   shared data & indices when generating the sparsified jaxpr. The
  #   current approach produces valid sparsified while loops, but they
  #   don't work in corner cases (see associated TODO in sparsify_test.py)
  out_tree = None

  # pe.trace_to_jaxpr_dynamic traces an lu.WrappedFun, not a plain callable.
  @lu.wrap_init
  def wrapped(*args_flat):
    nonlocal out_tree
    args = tree_unflatten(in_tree, args_flat)
    argspecs = arrays_to_argspecs(spenv, args)
    result = eval_sparse(jaxpr.jaxpr, jaxpr.consts, argspecs, spenv)
    out = argspecs_to_arrays(spenv, result)
    out_flat, out_tree = tree_flatten(out)
    return out_flat

  args = argspecs_to_arrays(spenv, argspecs)
  args_flat, in_tree = tree_flatten(args)
  avals_flat = [core.raise_to_shaped(core.get_aval(arg)) for arg in args_flat]
  sp_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(wrapped, avals_flat)
  sp_jaxpr = pe.ClosedJaxpr(pe.convert_constvars_jaxpr(sp_jaxpr), consts)
  return sp_jaxpr, out_tree
Example #18
    def jax_shape_for_update(update, shape_like):
        r"""Reshape a flat update vector into the pytree structure of ``shape_like``.

        Consecutive slices of ``update`` are assigned, in pytree-flattening
        order, to the leaves of ``shape_like``. Each slice is reshaped to the
        shape of its leaf.

        Args:
            update: a 1d jax/numpy array holding the concatenated leaves.
            shape_like: an instance with the same pytree structure and leaf
                shapes as the desired result. Leaves may be arrays or scalars.

        Returns:
            A possibly non-flat structure of jax arrays with the structure of
            ``shape_like``, containing a copy of the data in ``update``.
        """
        leaves, tree = tree_flatten(shape_like)

        updatelist = []
        k = 0
        for leaf in leaves:
            # jnp.size/jnp.shape also accept Python scalars, unlike .size/.shape.
            size = jnp.size(leaf)
            updatelist.append(
                jnp.asarray(update[k:k + size]).reshape(jnp.shape(leaf)))
            k += size

        return tree_unflatten(tree, updatelist)
Example #19
def remat_transpose(reduce_axes, out_cts, *in_primals, jaxpr, **params):
  """Transpose rule for ``remat_p``.

  The transposition (partial evaluation of ``jaxpr`` followed by
  ``ad.backward_pass``) is itself traced into a new jaxpr. That jaxpr is
  then bound with ``remat_p``, so the transpose is also wrapped in remat.

  Returns:
    The input cotangents, unflattened with the tree recorded while tracing.
  """
  assert not jaxpr.constvars
  cell = lambda: None  # attribute holder; transposed() stores .treedef on it

  # pe.trace_to_jaxpr_dynamic traces an lu.WrappedFun, not a plain callable.
  @lu.wrap_init
  def transposed(*args):
    in_primals, out_cts = tree_unflatten(treedef, args)
    in_pvals = [pe.PartialVal.unknown(x.aval) if ad.is_undefined_primal(x) else
                pe.PartialVal.known(x) for x in in_primals]
    primal_fun = lu.wrap_init(partial(core.eval_jaxpr, jaxpr, ()))
    tangent_jaxpr, _, consts = pe.trace_to_jaxpr(primal_fun, in_pvals, False)
    dummy_args = [ad.UndefinedPrimal(v.aval) for v in tangent_jaxpr.invars]
    in_cts_ = ad.backward_pass(tangent_jaxpr, reduce_axes, False, consts, dummy_args,
                               out_cts)
    in_cts, cell.treedef = tree_flatten(in_cts_)
    return in_cts

  args, treedef = tree_flatten((in_primals, out_cts))
  in_avals = [core.raise_to_shaped(core.get_aval(x)) for x in args]
  transposed_jaxpr_, _, consts = pe.trace_to_jaxpr_dynamic(transposed, in_avals)
  transposed_jaxpr = pe.convert_constvars_jaxpr(transposed_jaxpr_)
  in_cts = remat_p.bind(*consts, *args, jaxpr=transposed_jaxpr, **params)
  return tree_unflatten(cell.treedef, in_cts)  # type: ignore
Example #20
    def trainable_params_by_layer(self) -> Dict[str, int]:
        """Map each layer name to the element count of its parameter leaf.

        Leaves of ``tree_flatten(self)`` are paired positionally with the names
        from ``get_keys(self.dict())``. If any name repeats, every entry, not
        only the repeated ones, is stored as ``f"{name}_{i}"``, where ``i``
        counts earlier occurrences of that name.
        """
        params, _ = tree_flatten(self)
        layer_names = get_keys(self.dict())
        sizes: Dict[str, int] = {}
        if len(layer_names) != len(set(layer_names)):
            # Duplicate names: disambiguate by occurrence index.
            incrementer: Counter = Counter()
            for weights, name in zip(params, layer_names):
                new_name = f"{name}_{incrementer[name]}"
                incrementer[name] += 1
                sizes[new_name] = weights.size
        else:
            for weights, name in zip(params, layer_names):
                sizes[name] = weights.size

        return sizes
Example #21
def l2_norm_sq(params_tree, parameterization='ntk', W_std=1., b_std=0.05):
    """Parameterization dependent reweighted squared L2 norm of `params_tree`.

    Following Lemma 3 of https://arxiv.org/abs/1806.03335, we reweight weight
    and bias parameters according to their initialisation variances, which are
    `parameterization` dependent.

    Args:
        params_tree (pytree): Pytree of parameters to take L2 norm of
        parameterization (str): 'ntk' or 'standard'
        W_std (float): Weight standard deviation
        b_std (float): Bias standard deviation

    Returns:
        l2_norm (float): Weighted squared L2 norm of `params_tree`

    Raises:
        ValueError: if `parameterization` is neither 'ntk' nor 'standard'.
    """
    leaves, _ = tree_flatten(params_tree)

    # In NTK parameterisation all parameters have variance 1: easy reweighting
    if parameterization == 'ntk':
        return sum(np.vdot(leaf, leaf) for leaf in leaves)

    # In standard parameterization, need to multiply by `N/W_std^2` for weights
    # and `1/b_std^2` for biases, where N is width.
    # NB this only works for stax.Dense MLPs right now.
    # This is extremely ugly and precarious to the arbitrary choice between the weight
    # matrix and its transpose.
    if parameterization == 'standard':
        reg_list = []
        for leaf in leaves:
            assert leaf.ndim == 1 or leaf.ndim == 2
            if leaf.ndim == 1:
                reg_list.append(1 / b_std**2)
            else:
                reg_list.append(leaf.shape[0] / W_std**2)
        return sum(reg_coef * np.vdot(leaf, leaf)
                   for reg_coef, leaf in zip(reg_list, leaves))

    # Previously an unknown value silently returned None.
    raise ValueError(
        f"Unknown parameterization {parameterization!r}; "
        "expected 'ntk' or 'standard'.")
Example #22
def _calc_replica_ids(global_mesh: pxla.Mesh, mesh_axes: MeshAxes):
    """Compute a replica id for every device of ``global_mesh``.

    Mesh axes that do not appear in ``mesh_axes`` are replicated. A device's
    replica id is its row-major position along those replicated axes only.

    Returns:
      ``[0] * num_devices`` when no axis is replicated. Otherwise, an array of
      replica ids in ``global_mesh.devices.flat`` order.
    """
    pspec = _canonicalize_mesh_axes(mesh_axes)
    mesh_values = list(global_mesh.shape.values())
    flattened_pspec, _ = tree_flatten(tuple(pspec))
    # Get the location (coordinates) of each device in the device mesh.
    # NOTE(review): this treats each device id as its flat index into the
    # mesh; confirm ids are laid out that way.
    device_location = np.array(
        np.unravel_index([d.id for d in global_mesh.devices.flat],
                         mesh_values))
    # Find all the axes that were replicated.
    # If mesh_axes = (('x', 'y'), None, 'z') and ('x', 'y', 'z') were the mesh's
    # axis, then replicated axes will be None since all axes are being used to
    # shard the input.
    replicated_axis = np.isin(list(global_mesh.shape.keys()), flattened_pspec,
                              invert=True)
    # If all elements in replicated_axis are False then the input is fully sharded
    # so replica ids should be all 0s.
    if not any(replicated_axis):
        return [0] * global_mesh.devices.size
    # Drop all the sharded axes and find the location of coordinates in a linear
    # array shaped like the replicated axes.
    return np.ravel_multi_index(device_location[replicated_axis],
                                tuple(np.array(mesh_values)[replicated_axis]))
Example #23
def custom_transpose_transpose_rule(cts, *args, call, transpose, out_types,
                                    res_tree, lin_tree, out_tree):
    """Transpose rule that delegates to the user-supplied ``transpose``.

    ``args`` are split into residual and linear parts. Symbolic-zero output
    cotangents are instantiated as dense zeros, and
    ``transpose(res_arg, ct_out)`` is checked against ``lin_tree``.

    Returns:
      A flat list: ``None`` for every residual leaf, followed by the flattened
      linear cotangents broadcast to ``lin_tree``. ``None`` leaves are kept.
    """
    call_in_tree = treedef_tuple((res_tree, lin_tree))

    # TODO(frostig,mattjj): `lin_arg` indicates the inputs with respect
    # to which we are transposing (via `ad.is_undefined_primal`).
    # Consider passing this information to the custom transpose rule?

    res_arg, lin_arg = tree_unflatten(call_in_tree, args)
    del lin_arg
    assert all(not ad.is_undefined_primal(x) for x in tree_leaves(res_arg))

    # Symbolic zeros are materialized before calling `transpose`.
    cts = [
        ad_util.zeros_like_aval(ct.aval) if type(ct) is ad_util.Zero else ct
        for ct in cts
    ]
    ct_out = tree_unflatten(out_tree, cts)
    ct_lin = transpose(res_arg, ct_out)
    check_transpose_rule_trees(transpose, lin_tree, tree_structure(ct_lin))
    ct_lin_flat, _ = tree_flatten(
        tree_broadcast(lin_tree, ct_lin, is_leaf=lambda x: x is None),
        is_leaf=lambda x: x is None)
    return [None] * len(tree_leaves(res_arg)) + ct_lin_flat
Example #24
    def numpy_unflatten(self, data, shape_like):
        r"""Attempts a deserialization of the given numpy data.

        This is typically used to deserialize parameters and gradients.
        Consecutive slices of ``data`` fill the leaves of ``shape_like`` in
        pytree-flattening order.

        Args:
            data: a 1d numpy array.
            shape_like: an instance having the same type and shape as the
                desired conversion. Leaves may be arrays or scalars.

        Returns:
            A possibly non-flat structure of jax arrays containing a copy of
            data compatible with the given shape.
        """
        shf, tree = tree_flatten(shape_like)

        datalist = []
        k = 0
        for s in shf:
            # jax.numpy.size/shape also accept Python scalars, unlike .size/.shape.
            size = jax.numpy.size(s)
            datalist.append(
                jax.numpy.asarray(data[k:k + size]).reshape(jax.numpy.shape(s)))
            k += size

        return tree_unflatten(tree, datalist)
Example #25
    def result(*args, **kwargs):
        """Check that every backend agrees with ``function`` on these inputs.

        ``function`` gives the expected output. Evaluating ``jaxpr`` with
        ``j_core.eval_jaxpr``, ``numpy_function`` and ``numba_function`` must
        each produce a single output equal to it according to ``_are_equal``.
        If ``numba_function`` raises, the error is re-raised unless
        ``catch_numba`` is set, in which case the numba check is skipped.

        Returns:
          The expected output from ``function``.
        """
        expected = function(*args, **kwargs)

        arguments, _ = j_tree_util.tree_flatten((args, kwargs))

        (result,) = j_core.eval_jaxpr(jaxpr, constants, *arguments)
        assert _are_equal(result, expected)

        (result,) = numpy_function(*args, **kwargs)
        assert _are_equal(result, expected)

        # NOTE(review): the caught exception type was reconstructed; narrow it
        # if a specific numba error class is intended.
        try:
            (result,) = numba_function(*args, **kwargs)
        except Exception:
            if not catch_numba:
                raise
        else:
            assert _are_equal(result, expected)

        return expected
Example #26
def default_call_interpreter_rule(primitive: jax_core.CallPrimitive,
                                  rules: Rules, state: Value,
                                  invals: Sequence[Value],
                                  call_jaxpr: jax_core.Jaxpr,
                                  **params: Any) -> Tuple[Value, Value]:
  """Handles simple call primitives like `jax_core.call_p`.

  When evaluating call primitives, the input `state` needs to be an additional
  input to the call primitive and it also needs to return an additional output
  `state`. After flattening the state along with the regular inputs, this
  handler recursively calls `eval_jaxpr_with_state` on the primitive's
  `call_jaxpr`. The output state from the recursive call is returned from the
  call primitive.

  Args:
    primitive: A `jax_core.CallPrimitive` such as `jax_core.call_p`.
    rules: A `dict` that maps JAX primitives to functions that take in `(state,
      *args)` and return `(output, new_state)`.
    state: The interpreter `state` value at the time the call primitive is
      evaluated.
    invals: The input values to the call primitive.
    call_jaxpr: The `jax_core.Jaxpr` that corresponds to the body of the call
      primitive.
    **params: The parameters of the call primitive.

  Returns:
    A tuple of the output of the call primitive and its output state.
  """
  # Recursively use the effect handler for the call primitive's JAXpr.
  fun = lu.wrap_init(
      functools.partial(eval_jaxpr_with_state, call_jaxpr, rules, []))

  state_invals, state_invals_tree = tree_util.tree_flatten((state, *invals))
  flat_fun, out_tree = api_util.flatten_fun_nokwargs(fun, state_invals_tree)
  ans_state = primitive.bind(flat_fun, *state_invals, **params)
  return tree_util.tree_unflatten(out_tree(), ans_state)
Example #27
  def update_fn(updates, state, params=None):
    """Run ``inner.update`` only on the leaves selected by ``mask``.

    ``mask`` is a pytree of booleans, or a callable that returns one when
    given ``updates``. Leaves where the mask is false are returned unchanged.

    Returns:
      ``(new_updates, MaskedState(inner_state=...))``.
    """
    flat_mask, treedef = tree_flatten(mask(updates) if callable(mask) else mask)

    # Flatten then filter out updates/params not in the mask:
    flat_updates = treedef.flatten_up_to(updates)
    masked_updates = [g for g, m in zip(flat_updates, flat_mask) if m]

    if params is not None:
      flat_params = treedef.flatten_up_to(params)
      masked_params = [p for p, m in zip(flat_params, flat_mask) if m]
    else:
      masked_params = None

    # Compute new updates
    new_masked_updates, new_inner_state = inner.update(
        masked_updates, state.inner_state, masked_params)

    # Incorporate new_masked_updates into flat_updates, then unflatten
    new_masked_updates = iter(new_masked_updates)
    for i, m in enumerate(flat_mask):
      if m: flat_updates[i] = next(new_masked_updates)

    new_updates = treedef.unflatten(flat_updates)
    return new_updates, MaskedState(inner_state=new_inner_state)
Example #28
def sow(value, *, tag, name, mode='strict', key=None):
  """Marks a value with a name and a tag.

  Args:
    value: A JAX value to be tagged and named.
    tag (str): a string representing the tag of the sown value.
    name (str): a string representing the name to sow the value with.
    mode (str): The mode by which to sow the value. There are three options: 1.
      strict - if another value is sown with the same name and tag in the same
      context, harvest will throw an error. 2. clobber - if another value is
      sown with the same name and tag, it will replace this value 3. append -
      sown values of the same name and tag are appended to a growing list.
      Append mode assumes some ordering on the values being sown defined by
      data-dependence.
    key: an optional JAX value that will be tied into the sown value.

  Returns:
    The original `value` that was passed in.
  """
  if key is not None:
    value = prim.tie_in(key, value)
  flat_args, in_tree = tree_util.tree_flatten(value)
  out_flat = sow_p.bind(*flat_args, name=name, tag=tag, mode=mode, tree=in_tree)
  return tree_util.tree_unflatten(in_tree, out_flat)
Example #29
def l2_norm(tree):
    """Return the joint Euclidean (L2) norm of every leaf in a pytree.

    Handy as a weight-decay penalty.
    """
    squared = 0
    for leaf in tree_flatten(tree)[0]:
        squared = squared + np.vdot(leaf, leaf)
    return np.sqrt(squared)
Example #30
 def tree_init(x0_tree):
     """Initialize optimizer state for every leaf of ``x0_tree``.

     ``init`` is applied to each leaf. The per-leaf states are flattened and
     packed, and both the parameter treedef and the per-leaf state treedefs
     are kept in the returned ``OptimizerState``.
     """
     leaves, param_treedef = tree_flatten(x0_tree)
     per_leaf_states = [init(leaf) for leaf in leaves]
     flat_states, state_treedefs = unzip2(map(tree_flatten, per_leaf_states))
     packed = pack(map(pack, flat_states))
     return OptimizerState(packed, param_treedef, state_treedefs)