Example #1
def saved_residuals(f, *args,
                    **kwargs) -> List[Tuple[core.AbstractValue, str]]:
    args, in_tree = tree_flatten((args, kwargs))

    def f_(*args):
        args, kwargs = tree_unflatten(in_tree, args)
        return f(*args, **kwargs)

    jaxpr = jax.make_jaxpr(lambda *args: jax.linearize(f_, *args)[1])(
        *args).jaxpr
    res_lits = [x for x in jaxpr.outvars if isinstance(x, core.Literal)]
    res_vars = {x for x in jaxpr.outvars if not isinstance(x, core.Literal)}

    results = []

    for x in res_lits:
        results.append((x.aval, 'from a literal'))

    for v in jaxpr.constvars:
        if v in res_vars:
            results.append((v.aval, 'from a constant'))

    assert len(jaxpr.invars) == len(args)
    for i, v in enumerate(jaxpr.invars):
        if v in res_vars:
            src = f'from {pe.arg_info_pytree(f, in_tree, True, [i])}'
            results.append((v.aval, src))

    for eqn in jaxpr.eqns:
        src = source_info_util.summarize(eqn.source_info)
        for v in eqn.outvars:
            if v in res_vars:
                if eqn.primitive is name_p:
                    results.append(
                        (v.aval, f"named '{eqn.params['name']}' from {src}"))
                else:
                    results.append((v.aval, f'from {src}'))

    assert len(results) == len(jaxpr.outvars)
    return results
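
A quick way to exercise this helper, assuming the JAX internals it references (`core`, `pe`, `source_info_util`, `name_p`) are importable in the surrounding module, is to ask which residuals a small function would save:

import jax.numpy as jnp

def f(x, y):
    return jnp.sin(x) * y

# Each entry pairs an abstract value with a human-readable description of
# where the residual comes from (an argument, a constant, or an equation).
for aval, source in saved_residuals(f, 3.0, 4.0):
    print(aval, source)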
Example #2
def map_transpose(primitive, params, call_jaxpr, args, ct, _, reduce_axes):
  all_args, in_tree_def = tree_flatten(((), args, ct))  # empty consts
  fun = lu.hashable_partial(lu.wrap_init(backward_pass), call_jaxpr, reduce_axes, False)
  fun, nz_arg_cts = nonzero_outputs(fun)
  fun, out_tree = flatten_fun_nokwargs(fun, in_tree_def)
  # Preserve axis for primal arguments, skip tangents (represented as undefined primals).
  in_axes, out_axes = params['in_axes'], params['out_axes']
  new_in_axes = (*[axis for axis, x in zip(in_axes, args)
                   if not is_undefined_primal(x)],
                 *[axis for axis, x in zip(out_axes, ct)
                   if type(x) is not Zero])
  # The interim strategy we use below (until avals-with-names) only works
  # when all outputs are mapped.
  assert all(out_axis is not None for out_axis in out_axes), out_axes
  # NOTE: This assumes that the output cotangents being zero is a deterministic
  #       function of which input cotangents were zero.
  @as_hashable_function(closure=(in_axes, tuple(type(c) is Zero for c in ct)))
  def out_axes_thunk():
    return tuple(axis or 0 for axis, nz in zip(in_axes, nz_arg_cts()) if nz)
  new_params = dict(params, name=wrap_name(params['name'], 'transpose'),
                    in_axes=new_in_axes, out_axes_thunk=out_axes_thunk)
  del new_params['out_axes']
  update_params = call_transpose_param_updaters.get(primitive)
  if update_params:
    new_params = update_params(new_params, map(is_undefined_primal, args),
                               [type(x) is not Zero for x in ct])
  out_flat = primitive.bind(fun, *all_args, **new_params)
  arg_cts = tree_unflatten(out_tree(), out_flat)

  # The freevars are being fanned out (not mapped). During transpose the
  # dual of fan-out is fan-in-sum. We apply it to the unmapped invars.
  assert len(in_axes) == len(arg_cts)
  def unmap_zero(zero, in_axis):
    return (zero if in_axis is None else
            Zero(core.unmapped_aval(params['axis_size'], params['axis_name'], in_axis, zero.aval)))
  arg_cts = (unmap_zero(arg_ct, in_axis) if type(arg_ct) is Zero else
             arg_ct if in_axis is not None else
             arg_ct.sum(0)
             for arg_ct, in_axis in zip(arg_cts, in_axes))
  return tuple(arg_cts)
Example #3
def map_ildj(prim, incells, outcells, **params):
  """InverseAndILDJ rule for the map primitives."""
  f, incells = incells[0], incells[1:]

  def slice_aval(aval):
    return abstract_arrays.ShapedArray(aval.shape[1:], aval.dtype,
                                       aval.weak_type)

  def add_slice(cell, old_cell):
    new_slices = [
        NDSlice(ndslice.value, ndslice.ildj, Slice(0, old_cell.aval.shape[0]),
                *ndslice.slices) for ndslice in cell.slices
    ]
    return InverseAndILDJ(old_cell.aval, new_slices)

  def remove_slice(cell):
    new_slices = [
        NDSlice(ndslice.value, ndslice.ildj, *ndslice.slices[1:])
        for ndslice in cell.slices
    ]
    aval = slice_aval(cell.aval)
    return InverseAndILDJ(aval, new_slices)

  mapped_incells = safe_map(remove_slice, incells)
  mapped_outcells = safe_map(remove_slice, outcells)
  flat_vals, in_tree = tree_util.tree_flatten((mapped_incells, mapped_outcells))
  f, aux = flat_propagate(f, in_tree)
  # Assume all invars are mapped.
  new_mapped_invars = (True,) * len(flat_vals)
  new_params = dict(params, mapped_invars=new_mapped_invars)
  subenv_vals = prim.bind(f, *flat_vals, **new_params)
  subenv_tree = aux()
  subenv = tree_util.tree_unflatten(subenv_tree, subenv_vals)
  new_incells = [subenv.read(var) for var in subenv.jaxpr.invars]
  new_outcells = [subenv.read(var) for var in subenv.jaxpr.outvars]
  new_incells = [add_slice(v, old_v)
                 for old_v, v in safe_zip(incells, new_incells)]
  new_outcells = [add_slice(v, old_v)
                  for old_v, v in safe_zip(outcells, new_outcells)]
  return new_incells, new_outcells, subenv
Example #4
def haiku_module(name, nn_module, *, input_shape=None, **kwargs):
    """
    Declare a :mod:`~haiku` style neural network inside a
    model so that its parameters are registered for optimization via
    :func:`~numpyro.primitives.param` statements.

    :param str name: name of the module to be registered.
    :param haiku.Module nn_module: a `haiku` Module that has `.init` and
        `.apply` methods.
    :param tuple input_shape: shape of the input taken by the
        neural network.
    :param kwargs: optional keyword arguments to initialize the haiku neural
        network as an alternative to `input_shape`.
    :return: a callable with bound parameters that takes an array
        as an input and returns the neural network transformed output
        array.
    """
    try:
        import haiku  # noqa: F401
    except ImportError as e:
        raise ImportError(
            "It looks like you want to use haiku to declare nn modules. "
            "This is an experimental feature; you need to install `haiku` "
            "to use it. It can be installed with `pip install dm-haiku`."
        ) from e

    module_key = name + '$params'
    nn_params = numpyro.param(module_key)
    if nn_params is None:
        args = (jnp.ones(input_shape), ) if input_shape is not None else ()
        # feed in dummy data to init params
        rng_key = numpyro.prng_key()
        nn_params = nn_module.init(rng_key, *args, **kwargs)
        # haiku init returns an immutable dict; cast it to a mutable one so
        # that priors can be set for individual parameters
        nn_params = haiku.data_structures.to_mutable_dict(nn_params)
        # make sure that nn_params keeps the same order after unflattening
        params_flat, tree_def = tree_flatten(nn_params)
        nn_params = tree_unflatten(tree_def, params_flat)
        numpyro.param(module_key, nn_params)
    return partial(nn_module.apply, nn_params, None)
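
As a rough usage sketch, assuming `haiku` and `numpyro` are installed (the MLP and the names "mlp" and "y" below are illustrative, not part of the snippet):

import haiku as hk
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist

def model(x, y=None):
    # Register the transformed network's parameters under the name "mlp";
    # `input_shape` is used to build dummy data for initialization.
    net = haiku_module("mlp",
                       hk.transform(lambda x: hk.nets.MLP([16, 1])(x)),
                       input_shape=(1, x.shape[-1]))
    mean = net(x).squeeze(-1)
    numpyro.sample("y", dist.Normal(mean, 1.0), obs=y)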
Example #5
    def process_call(self, call_primitive, f, tracers, params):
        assert call_primitive.multiple_results
        primals, tangents = unzip2((t.primal, t.tangent) for t in tracers)
        nonzero_tangents, tangent_tree_def = tree_flatten(tangents)
        nz_tangents = [type(t) is not Zero for t in tangents]
        if 'name' in params and not config.jax_experimental_name_stack:
            params = dict(params, name=wrap_name(params['name'], 'jvp'))
        f_jvp = jvp_subtrace(f, self.main)
        f_jvp, nz_tangents_out = nonzero_tangent_outputs(f_jvp)
        if isinstance(call_primitive, core.MapPrimitive):
            in_axes = params['in_axes']
            tangent_in_axes = [
                ax for ax, nz in zip(in_axes, nz_tangents) if nz
            ]
            out_axes_thunk = params['out_axes_thunk']
            # The new thunk depends deterministically on the old thunk and the wrapped function.
            # Any caching already has to include the wrapped function as part of the key, so we
            # only use the previous thunk for equality checks.
            # NOTE: This assumes that the output tangents being zero is a deterministic
            #       function of which input tangents were zero.
            @as_hashable_function(closure=(tuple(nz_tangents), out_axes_thunk))
            def new_out_axes_thunk():
                out_axes = out_axes_thunk()
                return (*out_axes,
                        *(ax for ax, nz in zip(out_axes, nz_tangents_out())
                          if nz))

            params = dict(params,
                          in_axes=(*in_axes, *tangent_in_axes),
                          out_axes_thunk=new_out_axes_thunk)
        f_jvp, out_tree_def = traceable(f_jvp, len(primals), tangent_tree_def)
        update_params = call_param_updaters.get(call_primitive)
        new_params = update_params(params,
                                   nz_tangents) if update_params else params
        f_jvp = _update_annotation(f_jvp, f.in_type, nz_tangents)
        result = call_primitive.bind(f_jvp, *primals, *nonzero_tangents,
                                     **new_params)
        primal_out, tangent_out = tree_unflatten(out_tree_def(), result)
        return [JVPTracer(self, p, t) for p, t in zip(primal_out, tangent_out)]
Example #6
def _scan(
    f: Callable[[_Carry, _Input], Tuple[_Carry, _Output]],
    init: _Carry,
    xs: Iterable[_Input],
) -> Tuple[_Carry, _Output]:
    """Implements an unrolled version of scan.

  Based on `jax.lax.scan` and has a similar API.

  TODO(schsam): We introduce this function because lax.scan currently has a
  higher peak memory usage than the unrolled version. We will aim to swap this
  out for lax.scan when issue #1273 and related have been resolved.
  """
    carry = init
    ys = []
    flat_xs, tree_def = tree_flatten(xs)
    for flat_x in zip(*flat_xs):
        x = tree_unflatten(tree_def, flat_x)
        carry, y = f(carry, x)
        ys += [y]

    return carry, tree_multimap(lambda *y: np.stack(y), *ys)
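
For instance, assuming `np` is `jax.numpy` as the snippet's `np.stack` suggests, a cumulative sum written against the `lax.scan`-style API:

import jax.numpy as np

def cumsum_step(carry, x):
    new_carry = carry + x
    return new_carry, new_carry  # carry the running total forward and emit it

total, cumulative = _scan(cumsum_step, 0.0, np.arange(5.0))
# total == 10.0, cumulative == [0., 1., 3., 6., 10.]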
Example #7
  def start_tracing_body(self):
    """Called upon starting the tracing of the loop body."""
    # TODO: This is the first part of partial_eval.trace_to_subjaxpr. Share.
    self.trace = self.scope.start_subtrace()
    # The entire state is carried.
    self.carried_state_names = sorted(self.scope._mutable_state.keys())
    for key in self.carried_state_names:
      init_val = self.scope._mutable_state[key]
      flat_init_vals, init_tree = tree_util.tree_flatten(init_val)
      flat_init_avals = safe_map(_BodyTracer.abstractify, flat_init_vals)
      flat_init_pvals = safe_map(pe.PartialVal.unknown, flat_init_avals)
      flat_init_vars = safe_map(self.trace.new_arg, flat_init_pvals)
      self.carried_state_vars[key] = flat_init_vars
      # Set the scope._mutable_state to new tracing variables.
      self.scope._mutable_state[key] = init_tree.unflatten(flat_init_vars)
      self.scope._mutable_state_aval[key] = init_tree.unflatten(flat_init_avals)
      # Make a copy of the initial state by unflattening the flat_init_vals
      self.carried_state_initial[key] = init_tree.unflatten(flat_init_vals)

    index_var_aval = _BodyTracer.abstractify(0)
    index_var_pval = pe.PartialVal.unknown(index_var_aval)
    self._index_var = self.trace.new_arg(index_var_pval)
Example #8
 def transposed(*args):
     in_primals, out_cts = tree_unflatten(treedef, args)
     in_pvals = [
         pe.PartialVal.unknown(x.aval)
         if ad.is_undefined_primal(x) else pe.PartialVal.known(x)
         for x in in_primals
     ]
     primal_fun = lu.wrap_init(partial(core.eval_jaxpr, jaxpr, ()))
     t_jaxpr, _, consts = pe.trace_to_jaxpr_nounits(primal_fun, in_pvals,
                                                    False)
     dummy_args = [ad.UndefinedPrimal(v.aval) for v in t_jaxpr.invars]
     in_cts = ad.backward_pass(t_jaxpr, reduce_axes, False, consts,
                               dummy_args, out_cts)
     in_cts_ = iter(in_cts)
     in_cts = [
         next(in_cts_)
         if ad.is_undefined_primal(x) else ad_util.Zero(x.aval)
         for x in in_primals
     ]
     assert next(in_cts_, None) is None
     in_cts, cell.treedef = tree_flatten(in_cts)
     return in_cts
Example #9
 def __call__(self, *args: Any, **kwargs: Any) -> ReturnValue:
   if not self.jvp:
     msg = "No JVP defined for custom_jvp function {} using defjvp."
     raise AttributeError(msg.format(self.__name__))
   args = _resolve_kwargs(self.fun, args, kwargs)
   if self.nondiff_argnums:
     nondiff_argnums = set(self.nondiff_argnums)
     args = tuple(_stop_gradient(x) if i in nondiff_argnums else x
                  for i, x in enumerate(args))
     diff_argnums = [i for i in range(len(args)) if i not in nondiff_argnums]
     f_, dyn_args = argnums_partial(lu.wrap_init(self.fun), diff_argnums, args)
     static_args = [args[i] for i in self.nondiff_argnums]
     jvp = _add_args(lu.wrap_init(self.jvp), static_args)
   else:
     f_, dyn_args = lu.wrap_init(self.fun), args
     jvp = lu.wrap_init(self.jvp)
   args_flat, in_tree = tree_flatten(dyn_args)
   flat_fun, out_tree1 = flatten_fun_nokwargs(f_, in_tree)
   flat_jvp, out_tree2 = _flatten_jvp(jvp, in_tree)
   out_flat = custom_jvp_call_p.bind(flat_fun, flat_jvp, *args_flat)
   _, out_tree = lu.merge_linear_aux(out_tree1, out_tree2)
   return tree_unflatten(out_tree, out_flat)
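
This `__call__` is the entry point hit by an ordinary `jax.custom_jvp` definition, e.g.:

import jax
import jax.numpy as jnp

@jax.custom_jvp
def log1pexp(x):
    return jnp.log1p(jnp.exp(x))

@log1pexp.defjvp
def log1pexp_jvp(primals, tangents):
    x, = primals
    x_dot, = tangents
    # Numerically stable tangent rule: d/dx log(1 + e^x) = sigmoid(x).
    return log1pexp(x), x_dot * jax.nn.sigmoid(x)

print(jax.grad(log1pexp)(3.0))  # ~0.9526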
Example #10
def linearize(traceable, *primals, **kwargs):
  has_aux = kwargs.pop('has_aux', False)
  if not has_aux:
    jvpfun = jvp(traceable)
  else:
    jvpfun, aux = jvp(traceable, has_aux=True)

  in_pvals = (tuple(pe.PartialVal.known(p) for p in primals)
              + tuple(pe.PartialVal.unknown(get_aval(p).at_least_vspace())
                    for p in primals))
  _, in_tree = tree_flatten(((primals, primals), {}))
  jvpfun_flat, out_tree = flatten_fun(jvpfun, in_tree)
  jaxpr, out_pvals, consts = pe.trace_to_jaxpr(jvpfun_flat, in_pvals)
  out_primals_pvals, out_tangents_pvals = tree_unflatten(out_tree(), out_pvals)
  assert all(out_primal_pval.is_known() for out_primal_pval in out_primals_pvals)
  _, out_primals_consts = unzip2(out_primals_pvals)
  jaxpr.invars = jaxpr.invars[len(primals):]
  jaxpr.outvars = jaxpr.outvars[len(out_primals_pvals):]
  if not has_aux:
    return out_primals_consts, out_tangents_pvals, jaxpr, consts
  else:
    return out_primals_consts, out_tangents_pvals, jaxpr, consts, aux()
Example #11
def psum(x, axis_name, *, axis_index_groups=None):
  """Compute an all-reduce sum on ``x`` over the pmapped axis ``axis_name``.

  If ``x`` is a pytree then the result is equivalent to mapping this function to
  each leaf in the tree.

  Inputs of boolean dtype are converted to integers before the reduction.

  Args:
    x: array(s) with a mapped axis named ``axis_name``.
    axis_name: hashable Python object used to name a pmapped axis (see the
      :func:`jax.pmap` documentation for more details).
    axis_index_groups: optional list of lists containing axis indices (e.g. for
      an axis of size 4, [[0, 1], [2, 3]] would perform psums over the first
      two and last two replicas). Groups must cover all axis indices exactly
      once, and all groups must be the same size.


  Returns:
    Array(s) with the same shape as ``x`` representing the result of an
    all-reduce sum along the axis ``axis_name``.

  For example, with 4 XLA devices available:

  >>> x = np.arange(4)
  >>> y = jax.pmap(lambda x: jax.lax.psum(x, 'i'), axis_name='i')(x)
  >>> print(y)
  [6 6 6 6]
  >>> y = jax.pmap(lambda x: x / jax.lax.psum(x, 'i'), axis_name='i')(x)
  >>> print(y)
  [ 0.          0.16666667  0.33333334  0.5       ]
  """
  _validate_axis_index_groups(axis_index_groups)
  leaves, treedef = tree_util.tree_flatten(x)
  leaves = [lax.convert_element_type(l, onp.int32)
            if dtypes.dtype(l) == onp.bool_ else l for l in leaves]
  out_flat = psum_p.bind(*leaves, axis_name=axis_name,
                         axis_index_groups=axis_index_groups)
  return tree_util.tree_unflatten(treedef, out_flat)
Example #12
  def update(updates, state, params=None):
    inner_state = state.inner_state
    flat_updates = tree_flatten(updates)[0]
    isfinite = jnp.all(
        jnp.array([jnp.all(jnp.isfinite(p)) for p in flat_updates]))
    notfinite_count = jnp.where(isfinite, jnp.zeros([], jnp.int64),
                                1 + state.notfinite_count)

    def do_update(_):
      return inner.update(updates, inner_state, params)
    def reject_update(_):
      return (tree_map(jnp.zeros_like, updates), inner_state)

    updates, new_inner_state = lax.cond(
        jnp.logical_or(isfinite, notfinite_count > max_consecutive_errors),
        do_update, reject_update, operand=None)

    return updates, ApplyIfFiniteState(
        notfinite_count=notfinite_count,
        last_finite=isfinite,
        total_notfinite=jnp.logical_not(isfinite) + state.total_notfinite,
        inner_state=new_inner_state)
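
This `update` is the core of a guard along the lines of `optax.apply_if_finite`; a minimal sketch of how such a wrapper is typically used, assuming the optax API:

import jax.numpy as jnp
import optax

# Reject non-finite updates, giving up after 5 consecutive failures.
opt = optax.apply_if_finite(optax.adam(1e-3), max_consecutive_errors=5)
params = {'w': jnp.ones(3)}
state = opt.init(params)
grads = {'w': jnp.array([1.0, jnp.nan, 2.0])}
updates, state = opt.update(grads, state, params)
# The NaN batch is rejected: the returned updates are all zeros and the
# wrapper's notfinite_count is incremented.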
Example #13
def soft_vmap(fn, xs, batch_ndims=1, chunk_size=None):
    """
    Vectorizing map that maps a function `fn` over `batch_ndims` leading axes
    of `xs`. This uses jax.vmap over smaller chunks of the batch dimensions
    to keep memory usage constant.

    :param callable fn: The function to map over.
    :param xs: JAX pytree (e.g. an array, a list/tuple/dict of arrays,...)
    :param int batch_ndims: The number of leading dimensions of `xs`
        over which `fn` is applied element-wise.
    :param int chunk_size: Size of each chunk of `xs`.
        Defaults to the size of batch dimensions.
    :returns: output of `fn(xs)`.
    """
    flatten_xs = tree_flatten(xs)[0]
    batch_shape = np.shape(flatten_xs[0])[:batch_ndims]
    for x in flatten_xs[1:]:
        assert np.shape(x)[:batch_ndims] == batch_shape

    # we'll do map(vmap(fn), xs) and make xs.shape = (num_chunks, chunk_size, ...)
    num_chunks = batch_size = int(np.prod(batch_shape))
    prepend_shape = (-1,) if batch_size > 1 else ()
    xs = tree_map(lambda x: jnp.reshape(x, prepend_shape + jnp.shape(x)[batch_ndims:]), xs)
    # XXX: probably for the default behavior with chunk_size=None,
    # it is better to catch OOM error and reduce chunk_size by half until OOM disappears.
    chunk_size = batch_size if chunk_size is None else min(batch_size, chunk_size)
    if chunk_size > 1:
        pad = chunk_size - (batch_size % chunk_size)
        xs = tree_map(lambda x: jnp.pad(x, ((0, pad),) + ((0, 0),) * (np.ndim(x) - 1)), xs)
        num_chunks = batch_size // chunk_size + int(pad > 0)
        prepend_shape = (-1,) if num_chunks > 1 else ()
        xs = tree_map(lambda x: jnp.reshape(x, prepend_shape + (chunk_size,) + jnp.shape(x)[1:]), xs)
        fn = vmap(fn)

    ys = lax.map(fn, xs) if num_chunks > 1 else fn(xs)
    map_ndims = int(num_chunks > 1) + int(chunk_size > 1)
    ys = tree_map(lambda y: jnp.reshape(y, (-1,) + jnp.shape(y)[map_ndims:])[:batch_size], ys)
    return tree_map(lambda y: jnp.reshape(y, batch_shape + jnp.shape(y)[1:]), ys)
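
A small usage sketch (the argument names follow the function above; the shapes are illustrative):

import jax.numpy as jnp

# Map a row-wise function over 10,000 inputs in chunks of 100 rows at a time,
# trading a lax.map loop over chunks for lower peak memory than a full vmap.
xs = jnp.arange(10000.0).reshape(10000, 1)
ys = soft_vmap(lambda x: (x ** 2).sum(-1), xs, batch_ndims=1, chunk_size=100)
assert ys.shape == (10000,)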
Example #14
def _call_tf_impl(*args_jax_flat, args_treedef, func_tf, out_avals, **_):
    # On GPU we use dlpack to avoid copies of data to the host.
    def _arg_jax_to_tf(arg_jax):
        if (isinstance(arg_jax, xla.DeviceArray)
                and arg_jax.device_buffer.client.platform in _DLPACK_PLATFORMS
                and arg_jax.dtype in dlpack.SUPPORTED_DTYPES):
            arg_dlpack = jax.dlpack.to_dlpack(arg_jax, take_ownership=False)
            return tf.experimental.dlpack.from_dlpack(arg_dlpack)
        # The following avoids copies to the host on CPU, always for DeviceArray
        # and even for ndarray if they are sufficiently aligned.
        # TODO(necula): on TPU this copies to the host!
        return tf.constant(np.asarray(arg_jax))

    args_tf_flat = tuple(map(_arg_jax_to_tf, args_jax_flat))
    with jax2tf_internal.inside_call_tf():
        # Call in TF eager mode
        res_tf = func_tf(*args_treedef.unflatten(args_tf_flat))
    res_tf_flat, _ = tree_util.tree_flatten(res_tf)

    # TODO(necula): check the result for tree and aval

    def _res_tf_to_jax(res_tf: TfVal, out_aval: core.AbstractValue):
        res_tf, _ = jax2tf_internal._tfval_to_tensor_jax_dtype(
            res_tf, jax_dtype=out_aval.dtype)
        if isinstance(res_tf,
                      tf.Tensor) and res_tf.dtype in dlpack.SUPPORTED_DTYPES:
            res_tf_platform = tf.DeviceSpec.from_string(
                res_tf.backing_device).device_type
            res_jax_platform = res_tf_platform.lower()
            if res_jax_platform in _DLPACK_PLATFORMS:
                res_dlpack = tf.experimental.dlpack.to_dlpack(res_tf)
                return jax.dlpack.from_dlpack(
                    res_dlpack,
                    backend=xla_bridge.get_backend(res_jax_platform))

        return jnp.asarray(np.asarray(res_tf))

    return list(map(_res_tf_to_jax, res_tf_flat, out_avals))
Example #15
 def __call__(self, *args: Any, **kwargs: Any) -> ReturnValue:  # pytype: disable=invalid-annotation
     if not self.fwd or not self.bwd:
         msg = "No VJP defined for custom_vjp function {} using defvjp."
         raise AttributeError(msg.format(self.__name__))
     args = _resolve_kwargs(self.fun, args, kwargs)
     if self.nondiff_argnums:
         for i in self.nondiff_argnums:
             _check_for_tracers(args[i])
         nondiff_argnums = set(self.nondiff_argnums)
         dyn_argnums = [
             i for i in range(len(args)) if i not in nondiff_argnums
         ]
         f_, dyn_args = argnums_partial(lu.wrap_init(self.fun),
                                        dyn_argnums,
                                        args,
                                        require_static_args_hashable=False)
         static_args = [args[i] for i in self.nondiff_argnums]
         fwd, _ = argnums_partial(lu.wrap_init(self.fwd),
                                  dyn_argnums,
                                  args,
                                  require_static_args_hashable=False)
         bwd = _add_args(lu.wrap_init(self.bwd), static_args)
     else:
         f_, dyn_args = lu.wrap_init(self.fun), args
         fwd, bwd = lu.wrap_init(self.fwd), lu.wrap_init(self.bwd)
     args_flat, in_tree = tree_flatten(dyn_args)
     in_avals = [core.raise_to_shaped(core.get_aval(x)) for x in args_flat]
     flat_fun, out_tree = flatten_fun_nokwargs(f_, in_tree)
     flat_fwd, out_trees = _flatten_fwd(fwd, in_tree)
     flat_bwd = _flatten_bwd(bwd, in_tree, in_avals, out_trees)
     out_flat = custom_vjp_call_p.bind(flat_fun,
                                       flat_fwd,
                                       flat_bwd,
                                       *args_flat,
                                       out_trees=out_trees)
     fst, aux = lu.merge_linear_aux(out_tree, out_trees)
     out_tree = aux if fst else aux[0]
     return tree_unflatten(out_tree, out_flat)
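
Likewise, this `__call__` is what runs when a `jax.custom_vjp` function is applied:

import jax
import jax.numpy as jnp

@jax.custom_vjp
def clipped_sin(x):
    return jnp.sin(x)

def clipped_sin_fwd(x):
    return jnp.sin(x), jnp.cos(x)           # save cos(x) as the residual

def clipped_sin_bwd(res, g):
    return (jnp.clip(res * g, -0.5, 0.5),)  # clip the incoming cotangent

clipped_sin.defvjp(clipped_sin_fwd, clipped_sin_bwd)
print(jax.grad(clipped_sin)(0.0))  # 0.5 instead of 1.0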
Example #16
 def process_higher_order_primitive(self, trace, call_primitive, f, tracers,
                                    params, is_map):
     del is_map
     name = jax_util.wrap_name(params.pop('name', f.__name__), 'reap')
     context = trace_util.get_dynamic_context(trace)
     vals = [t.val for t in tracers]
     plants = context.plants
     if 'in_axes' in params:
         # TODO(b/199459308): figure out if invars are mapped or unmapped
         params = dict(params,
                       in_axes=(0, ) * len(tree_util.tree_leaves(plants)) +
                       params['in_axes'])
     if 'donated_invars' in params:
         params = dict(params)
         params['donated_invars'] = (
             (False, ) * len(tree_util.tree_leaves(plants)) +
             params['donated_invars'])
     elif call_primitive is nest_p:
         plants = plants.get(params['scope'], {})
     all_vals, all_tree = tree_util.tree_flatten((plants, vals))
     f = plant_eval(f, trace, self.settings, all_tree)
     out_vals = call_primitive.bind(f, *all_vals, name=name, **params)
     return jax_util.safe_map(trace.pure, out_vals)
Example #17
def _sparsify_jaxpr(spenv, jaxpr, *argspecs):
  # TODO(jakevdp): currently this approach discards all information about
  #   shared data & indices when generating the sparsified jaxpr. The
  #   current approach produces valid sparsified while loops, but they
  #   don't work in corner cases (see associated TODO in sparsify_test.py)
  out_tree = None

  @lu.wrap_init
  def wrapped(*args_flat):
    nonlocal out_tree
    args = tree_unflatten(in_tree, args_flat)
    argspecs = arrays_to_argspecs(spenv, args)
    result = eval_sparse(jaxpr.jaxpr, jaxpr.consts, argspecs, spenv)
    out = argspecs_to_arrays(spenv, result)
    out_flat, out_tree = tree_flatten(out)
    return out_flat

  args = argspecs_to_arrays(spenv, argspecs)
  args_flat, in_tree = tree_flatten(args)
  avals_flat = [core.raise_to_shaped(core.get_aval(arg)) for arg in args_flat]
  sp_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(wrapped, avals_flat)
  sp_jaxpr = pe.ClosedJaxpr(pe.convert_constvars_jaxpr(sp_jaxpr), consts)
  return sp_jaxpr, out_tree
Example #18
    def jax_shape_for_update(update, shape_like):
        r"""Reshapes grads from array to tree like structure if neccesary for update

        Args:
            grads: a 1d jax/numpy array
            shape_like: this as in instance having the same type and shape of
                        the desired conversion.

        Returns:
            A possibly non-flat structure of jax arrays containing a copy of data
            compatible with the given shape if jax_available and a copy of update otherwise
        """

        shf, tree = tree_flatten(shape_like)

        updatelist = []
        k = 0
        for s in shf:
            size = s.size
            updatelist.append(jnp.asarray(update[k:k + size]).reshape(s.shape))
            k += size

        return tree_unflatten(tree, updatelist)
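
For example, with a hypothetical parameter tree (the helper only needs something with the target shapes):

import jax.numpy as jnp

shape_like = {'b': jnp.zeros(3), 'w': jnp.zeros((2, 3))}
flat_update = jnp.arange(9.0)
tree_update = jax_shape_for_update(flat_update, shape_like)
# tree_update['b'].shape == (3,) and tree_update['w'].shape == (2, 3),
# filled with consecutive slices of flat_update.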
Example #19
def remat_transpose(reduce_axes, out_cts, *in_primals, jaxpr, **params):
  assert not jaxpr.constvars
  cell = lambda: None

  @lu.wrap_init
  def transposed(*args):
    in_primals, out_cts = tree_unflatten(treedef, args)
    in_pvals = [pe.PartialVal.unknown(x.aval) if ad.is_undefined_primal(x) else
                pe.PartialVal.known(x) for x in in_primals]
    primal_fun = lu.wrap_init(partial(core.eval_jaxpr, jaxpr, ()))
    tangent_jaxpr, _, consts = pe.trace_to_jaxpr(primal_fun, in_pvals, False)
    dummy_args = [ad.UndefinedPrimal(v.aval) for v in tangent_jaxpr.invars]
    in_cts_ = ad.backward_pass(tangent_jaxpr, reduce_axes, False, consts, dummy_args,
                               out_cts)
    in_cts, cell.treedef = tree_flatten(in_cts_)
    return in_cts

  args, treedef = tree_flatten((in_primals, out_cts))
  in_avals = [core.raise_to_shaped(core.get_aval(x)) for x in args]
  transposed_jaxpr_, _, consts = pe.trace_to_jaxpr_dynamic(transposed, in_avals)
  transposed_jaxpr = pe.convert_constvars_jaxpr(transposed_jaxpr_)
  in_cts = remat_p.bind(*consts, *args, jaxpr=transposed_jaxpr, **params)
  return tree_unflatten(cell.treedef, in_cts)  # type: ignore
Example #20
    def trainable_params_by_layer(self) -> Dict[str, int]:
        params, _ = tree_flatten(self)
        layer_names = get_keys(self.dict())
        sizes: Dict[str, int] = {}
        if len(layer_names) != len(set(layer_names)):
            counts = Counter(layer_names)
            incrementer: Counter = Counter()
            for weights, name in zip(params, layer_names):
                if counts[name] > 0:
                    counts[name] -= 1
                    new_name = f"{name}_{incrementer[name]}"
                    incrementer[name] += 1
                    sizes[new_name] = weights.size
        else:
            for weights, name in zip(params, layer_names):
                sizes[name] = weights.size

        return sizes
Example #21
def l2_norm_sq(params_tree, parameterization='ntk', W_std=1., b_std=0.05):
    """Parameterization dependent reweighted L2 norm of `params_tree`

    Following Lemma 3 of https://arxiv.org/abs/1806.03335, we reweight weight
    and bias parameters according to their initialisation variances, which are
    `parameterization` dependent.

    Args:
        params_tree (pytree): Pytree of parameters to take L2 norm of
        parameterization (str): 'ntk' or 'standard'
        W_std (float): Weight standard deviation
        b_std (float): Bias standard deviation

    Returns:
        l2_norm (float): Weighted L2 norm of `params_tree`
    """
    leaves, _ = tree_flatten(params_tree)

    # In NTK parameterisation all parameters have variance 1: easy reweighting
    if parameterization == 'ntk':
        return sum(np.vdot(leaf, leaf) for leaf in leaves)

    # In standard parameterization, need to multiply by `N/W_std^2` for weights
    # and `1/b_std^2` for biases, where N is width.
    # NB: this only works for stax.Dense MLPs right now.
    # This is fragile, since it depends on the arbitrary choice between the
    # weight matrix and its transpose.
    elif parameterization == 'standard':
        reg_list = []
        for leaf in leaves:
            assert leaf.ndim == 1 or leaf.ndim == 2
            if leaf.ndim == 1:
                reg_list.append(1 / b_std**2)
            elif leaf.ndim == 2:
                reg_list.append(leaf.shape[0] / W_std**2)
        return sum(reg_coef * np.vdot(leaf, leaf)
                   for reg_coef, leaf in zip(reg_list, leaves))
    else:
        # Avoid silently returning None for an unrecognized parameterization.
        raise ValueError(f"Unknown parameterization: {parameterization}")
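
For illustration, treating `np` as `jax.numpy` (as in the snippet) and using a hypothetical two-layer parameter list:

import jax.numpy as np
from jax import random

key = random.PRNGKey(0)
params = [(random.normal(key, (10, 64)), np.zeros(64)),   # (W1, b1)
          (random.normal(key, (64, 1)), np.zeros(1))]     # (W2, b2)
print(l2_norm_sq(params, parameterization='ntk'))
print(l2_norm_sq(params, parameterization='standard', W_std=1.0, b_std=0.05))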
Example #22
def _calc_replica_ids(global_mesh: pxla.Mesh, mesh_axes: MeshAxes):
    pspec = _canonicalize_mesh_axes(mesh_axes)
    mesh_values = list(global_mesh.shape.values())
    flattened_pspec, _ = tree_flatten(tuple(pspec))
    # Get the location (coordinates) of each device in the device mesh.
    device_location = np.array(
        np.unravel_index([d.id for d in global_mesh.devices.flat],
                         mesh_values))
    # Find all the axes that were replicated.
    # If mesh_axes = (('x', 'y'), None, 'z') and the mesh's axes are
    # ('x', 'y', 'z'), then there are no replicated axes, since every mesh
    # axis is used to shard the input.
    replicated_axis = np.isin(list(global_mesh.shape.keys()),
                              flattened_pspec,
                              invert=True)
    # If all elements in replicated_axis are False then the input is fully sharded
    # so replica ids should be all 0s.
    if not any(replicated_axis):
        return [0] * global_mesh.devices.size
    else:
        # Drop all the sharded axes and find the location of coordinates in a linear
        # array.
        return np.ravel_multi_index(device_location[replicated_axis],
                                    np.array(mesh_values)[replicated_axis])
Example #23
def custom_transpose_transpose_rule(cts, *args, call, transpose, out_types,
                                    res_tree, lin_tree, out_tree):
    call_in_tree = treedef_tuple((res_tree, lin_tree))

    # TODO(frostig,mattjj): `lin_arg` indicates the inputs with respect
    # to which we are transposing (via `ad.is_undefined_primal`).
    # Consider passing this information to the custom transpose rule?

    res_arg, lin_arg = tree_unflatten(call_in_tree, args)
    del lin_arg
    assert all(not ad.is_undefined_primal(x) for x in tree_leaves(res_arg))

    cts = [
        ad_util.zeros_like_aval(ct.aval) if type(ct) is ad_util.Zero else ct
        for ct in cts
    ]
    ct_out = tree_unflatten(out_tree, cts)
    ct_lin = transpose(res_arg, ct_out)
    check_transpose_rule_trees(transpose, lin_tree, tree_structure(ct_lin))
    ct_lin_flat, _ = tree_flatten(tree_broadcast(lin_tree,
                                                 ct_lin,
                                                 is_leaf=lambda x: x is None),
                                  is_leaf=lambda x: x is None)
    return [None] * len(tree_leaves(res_arg)) + ct_lin_flat
Example #24
    def numpy_unflatten(self, data, shape_like):
        r"""Attempts a deserialization of the given numpy data.
            This is typically used to deserialize parameters and gradients.

        Args:
             data: a 1d numpy array.
             shape_like: an instance having the same type and shape as
                         the desired conversion.

        Returns:
             A possibly non-flat structure of jax arrays containing a copy of data
             compatible with the given shape.
        """
        shf, tree = tree_flatten(shape_like)

        datalist = []
        k = 0
        for s in shf:
            size = s.size
            datalist.append(
                jax.numpy.asarray(data[k:k + size]).reshape(s.shape))
            k += size

        return tree_unflatten(tree, datalist)
Example #25
    def result(*args, **kwargs):
        expected = function(*args, **kwargs)

        arguments, _ = j_tree_util.tree_flatten((args, kwargs))

        (result,) = j_core.eval_jaxpr(jaxpr, constants, *arguments)
        assert _are_equal(result, expected)

        (result,) = numpy_function(*args, **kwargs)
        assert _are_equal(result, expected)

        try:
            (result,) = numba_function(*args, **kwargs)

        except Exception:
            if not catch_numba:
                raise

            traceback.print_exc()

        else:
            assert _are_equal(result, expected)

        return expected
Example #26
def default_call_interpreter_rule(primitive: jax_core.CallPrimitive,
                                  rules: Rules, state: Value,
                                  invals: Sequence[Value],
                                  call_jaxpr: jax_core.Jaxpr,
                                  **params: Any) -> Tuple[Value, Value]:
  """Handles simple call primitives like `jax_core.call_p`.

  When evaluating call primitives, the input `state` needs to be an additional
  input to the call primitive and it also needs to return an additional output
  `state`. After flattening the state along with the regular inputs, this
  handler recursively calls `eval_jaxpr_with_state` on the primitive's
  `call_jaxpr`. The output state from the recursive call is returned from the
  call primitive.

  Args:
    primitive: A `jax_core.CallPrimitive` such as `jax_core.call_p`.
    rules: A `dict` that maps JAX primitives to functions that take in `(state,
      *args)` and return `(output, new_state)`.
    state: The interpreter `state` value at the time the call primitive is
      evaluated.
    invals: The input values to the call primitive.
    call_jaxpr: The `jax_core.Jaxpr` that corresponds to the body of the call
      primitive.
    **params: The parameters of the call primitive.

  Returns:
    A tuple of the output of the call primitive and its output state.
  """
  # Recursively use the effect handler for the call primitive's JAXpr.
  fun = lu.wrap_init(
      functools.partial(eval_jaxpr_with_state, call_jaxpr, rules, []))

  state_invals, state_invals_tree = tree_util.tree_flatten((state, *invals))
  flat_fun, out_tree = api_util.flatten_fun_nokwargs(fun, state_invals_tree)
  ans_state = primitive.bind(flat_fun, *state_invals, **params)
  return tree_util.tree_unflatten(out_tree(), ans_state)
Example #27
  def update_fn(updates, state, params=None):
    flat_mask, treedef = tree_flatten(mask(updates) if callable(mask) else mask)

    # Flatten then filter out updates/params not in the mask:
    flat_updates = treedef.flatten_up_to(updates)
    masked_updates = [g for g, m in zip(flat_updates, flat_mask) if m]

    if params is not None:
      flat_params = treedef.flatten_up_to(params)
      masked_params = [p for p, m in zip(flat_params, flat_mask) if m]
    else:
      masked_params = None

    # Compute new updates
    new_masked_updates, new_inner_state = inner.update(
        masked_updates, state.inner_state, masked_params)

    # Incorporate new_masked_updates into flat_updates, then unflatten
    new_masked_updates = iter(new_masked_updates)
    for i, m in enumerate(flat_mask):
      if m: flat_updates[i] = next(new_masked_updates)

    new_updates = treedef.unflatten(flat_updates)
    return new_updates, MaskedState(inner_state=new_inner_state)
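
This is the core of a masking wrapper along the lines of `optax.masked`; a minimal sketch of the intended use, assuming the optax API:

import jax.numpy as jnp
import optax

params = {'w': jnp.ones(3), 'b': jnp.zeros(2)}
mask = {'w': True, 'b': False}   # only 'w' goes through the inner optimizer
opt = optax.masked(optax.sgd(1e-2), mask)
state = opt.init(params)
grads = {'w': jnp.ones(3), 'b': jnp.ones(2)}
updates, state = opt.update(grads, state, params)
# updates['w'] has been scaled by -1e-2; updates['b'] passes through unchanged.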
Example #28
def sow(value, *, tag, name, mode='strict', key=None):
  """Marks a value with a name and a tag.

  Args:
    value: A JAX value to be tagged and named.
    tag (str): a string representing the tag of the sown value.
    name (str): a string representing the name to sow the value with.
    mode (str): The mode by which to sow the value. There are three options:
      1. strict - if another value is sown with the same name and tag in the
         same context, harvest will throw an error.
      2. clobber - if another value is sown with the same name and tag, it will
         replace this value.
      3. append - sown values of the same name and tag are appended to a
         growing list. Append mode assumes some ordering on the values being
         sown defined by data-dependence.
    key: an optional JAX value that will be tied into the sown value.

  Returns:
    The original `value` that was passed in.
  """
  if key is not None:
    value = prim.tie_in(key, value)
  flat_args, in_tree = tree_util.tree_flatten(value)
  out_flat = sow_p.bind(*flat_args, name=name, tag=tag, mode=mode, tree=in_tree)
  return tree_util.tree_unflatten(in_tree, out_flat)
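
A sown value is typically collected later with `reap`; the import path below is assumed to be `oryx.core`:

from oryx.core import reap  # assumed import path for oryx's harvest API

def f(x):
    y = sow(x ** 2, tag='intermediate', name='y')
    return y + 1

# reap gathers every value sown under the given tag, keyed by name.
print(reap(f, tag='intermediate')(2.0))  # {'y': 4.0}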
Example #29
def l2_norm(tree):
    """Compute the l2 norm of a pytree of arrays. Useful for weight decay."""
    leaves, _ = tree_flatten(tree)
    return np.sqrt(sum(np.vdot(x, x) for x in leaves))
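
E.g., as a weight-decay penalty, treating `np` as `jax.numpy` to match the snippet and using an illustrative parameter list:

import jax.numpy as np

params = [np.ones((3, 3)), np.arange(3.0)]
penalty = 1e-4 * l2_norm(params) ** 2   # squared l2 norm as weight decay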
Example #30
 def tree_init(x0_tree):
     x0_flat, tree = tree_flatten(x0_tree)
     initial_states = [init(x0) for x0 in x0_flat]
     states_flat, subtrees = unzip2(map(tree_flatten, initial_states))
     packed_state = pack(map(pack, states_flat))
     return OptimizerState(packed_state, tree, subtrees)