Example #1
0
 def testTreedefTupleComparesEqual(self):
   # https://github.com/google/jax/issues/9066
   self.assertEqual(tree_util.tree_structure((3,)),
                    tree_util.treedef_tuple((tree_util.tree_structure(3),)))
Example #2
0
 def testTransposeMismatchOuter(self):
   tree = {"a": [1, 2], "b": [3, 4]}
   outer_treedef = tree_util.tree_structure({"a": 1, "b": 2, "c": 3})
   inner_treedef = tree_util.tree_structure([1, 2])
   with self.assertRaisesRegex(TypeError, "Mismatch"):
     tree_util.tree_transpose(outer_treedef, inner_treedef, tree)
Example #3
0
def unflatten_tree(tree, xs):
    """Rebuild a pytree shaped like ``tree`` from the flat leaves ``xs``.

    This is the inverse of ``flatten_tree``. Only the structure of ``tree``
    is used as a template; its leaf values are ignored. ``xs`` supplies the
    replacement leaves in flattening order.
    """
    structure = tree_util.tree_structure(tree)
    return tree_util.tree_unflatten(structure, xs)
Example #4
0
def jet(fun, primals, series):
    r"""Taylor-mode higher-order automatic differentiation.

  Args:
    fun: Function to be differentiated. Its arguments should be arrays, scalars,
      or standard Python containers of arrays or scalars. It should return an
      array, scalar, or standard Python container of arrays or scalars.
    primals: The primal values at which the Taylor approximation of ``fun`` should be
      evaluated. Should be either a tuple or a list of arguments,
      and its length should be equal to the number of positional parameters of
      ``fun``.
    series: Higher order Taylor-series-coefficients.
      Together, `primals` and `series` make up a truncated Taylor polynomial.
      Should be either a tuple or a list of tuples or lists,
      and its length dictates the degree of the truncated Taylor polynomial.

  Returns:
    A ``(primals_out, series_out)`` pair, where ``primals_out`` is ``fun(*primals)``,
    and together, ``primals_out`` and ``series_out`` are a
    truncated Taylor polynomial of :math:`f(h(\cdot))`.
    The ``primals_out`` value has the same Python tree structure as ``primals``,
    and the ``series_out`` value the same Python tree structure as ``series``.

  For example:

  >>> import jax
  >>> import jax.numpy as np

  Consider the function :math:`h(z) = z^3`, :math:`x = 0.5`,
  and the first few Taylor coefficients
  :math:`h_0=x^3`, :math:`h_1=3x^2`, and :math:`h_2=6x`.
  Let :math:`f(y) = \sin(y)`.

  >>> h0, h1, h2 = 0.5**3., 3.*0.5**2., 6.*0.5
  >>> f, df, ddf = np.sin, np.cos, lambda *args: -np.sin(*args)

  :func:`jet` returns the Taylor coefficients of :math:`f(h(z)) = \sin(z^3)`
  according to Faà di Bruno's formula:

  >>> f0, (f1, f2) =  jet(f, (h0,), ((h1, h2),))
  >>> print(f0,  f(h0))
  0.12467473 0.12467473

  >>> print(f1, df(h0) * h1)
  0.7441479 0.74414825

  >>> print(f2, ddf(h0) * h1 ** 2 + df(h0) * h2)
  2.9064622 2.9064634
  """
    try:
        order, = set(map(len, series))
    except ValueError:
        msg = "jet terms have inconsistent lengths for different arguments"
        raise ValueError(msg) from None

    # TODO(mattjj): consider supporting pytree inputs
    for i, (x, terms) in enumerate(zip(primals, series)):
        treedef = tree_structure(x)
        if not treedef_is_leaf(treedef):
            raise ValueError(
                "primal value at position {} is not an array".format(i))
        for j, t in enumerate(terms):
            treedef = tree_structure(t)
            if not treedef_is_leaf(treedef):
                raise ValueError(
                    "term {} for argument {} is not an array".format(j, i))

    @lu.transformation_with_aux
    def flatten_fun_output(*args):
        ans = yield args, {}
        yield tree_flatten(ans)

    f, out_tree = flatten_fun_output(lu.wrap_init(fun))
    out_primals, out_terms = jet_fun(jet_subtrace(f),
                                     order).call_wrapped(primals, series)
    return tree_unflatten(out_tree(),
                          out_primals), tree_unflatten(out_tree(), out_terms)
Example #5
0
def _reap_scan_rule(trace: HarvestTrace, *tracers, length, reverse, jaxpr,
                    num_consts, num_carry, linear, unroll):
    """Reaps the body of a scan to pull out `clobber` and `append` sows.

    The scan body is re-traced so that every sow inside it is reaped:
    - `clobber` sows are threaded through the scan as extra carry values.
    - `append` sows are emitted as extra per-iteration outputs.

    After the rewritten scan runs, each reaped value is sown again at this
    (outer) level, using the mode recorded for it. `strict` sows are rejected.

    Args:
      trace: The active harvest trace. Its dynamic context supplies the reap
        settings (tag, allowlist, blocklist, exclusive).
      *tracers: The scan operands, laid out as consts, then carry, then xs.
      length, reverse, jaxpr, num_consts, num_carry, linear, unroll: The
        parameters of the original `scan` primitive.

    Returns:
      The final carry values followed by the ys outputs of the original body.

    Raises:
      ValueError: if a sow inside the scan body uses `strict` mode.
    """

    # Split the flat operands into consts / carry / xs.
    const_tracers, carry_tracers, xs_tracers = jax_util.split_list(
        tracers, [num_consts, num_carry])
    _, carry_avals, xs_avals = tree_util.tree_map(
        lambda x: x.aval, (const_tracers, carry_tracers, xs_tracers))
    const_vals, carry_vals, xs_vals = tree_util.tree_map(
        lambda x: x.val, (const_tracers, carry_tracers, xs_tracers))
    context = trace_util.get_dynamic_context(trace)
    settings = context.settings
    # For each xs operand, take a single-iteration slice (t[0]) when the tracer
    # supports indexing. Presumably this yields the per-iteration inputs that
    # the body sees (TODO confirm).
    x_tracers = [t[0] if hasattr(t, '_getitem') else t for t in xs_tracers]
    x_avals = [t.aval for t in x_tracers]
    x_vals = [t.val for t in x_tracers]
    # Maps each sow name found in the body to its metadata; the 'mode' and
    # 'aval' entries are read below.
    metadata = _get_harvest_metadata(jaxpr, settings,
                                     *(const_vals + carry_vals + x_vals))

    # Group sow names by mode; clobber sows become extra carry entries.
    reap_modes = collections.defaultdict(set)
    reap_carry_avals = {}
    for name, meta in metadata.items():
        mode = meta['mode']
        aval = meta['aval']
        if mode == 'strict':
            raise ValueError(
                f'Cannot use strict mode for \'{name}\' inside `scan`.')
        reap_modes[mode].add(name)
        if mode == 'clobber':
            reap_carry_avals[name] = aval
    body_fun = jax_core.jaxpr_as_fun(jaxpr)

    reap_carry_flat_avals, _ = tree_util.tree_flatten(reap_carry_avals)

    # Input tree of the new body: ((original carry, reap carry), xs).
    reap_carry_in_tree = tree_util.tree_structure(
        ((carry_avals, reap_carry_avals), xs_avals))

    def new_body(carry, x):
        # The incoming reap carry is discarded; each iteration's clobber reaps
        # replace it, so the final carry holds the last iteration's values.
        carry, _ = carry
        all_values = const_vals + tree_util.tree_leaves((carry, x))
        out, reaps = call_and_reap(body_fun,
                                   tag=settings.tag,
                                   allowlist=settings.allowlist,
                                   blocklist=settings.blocklist,
                                   exclusive=settings.exclusive)(*all_values)
        carry_out, y = jax_util.split_list(out, [num_carry])
        carry_reaps = {
            name: val
            for name, val in reaps.items() if name in reap_modes['clobber']
        }
        xs_reaps = {
            name: val
            for name, val in reaps.items() if name in reap_modes['append']
        }
        return (carry_out, carry_reaps), (y, xs_reaps)

    new_body_jaxpr, consts, out_tree = lcf._initial_style_jaxpr(  # pylint: disable=protected-access
        new_body, reap_carry_in_tree,
        tuple(carry_avals + reap_carry_flat_avals + x_avals))
    # Zero-filled initial values for the clobber reap carry.
    dummy_reap_carry_vals = tree_util.tree_map(
        lambda x: jnp.zeros(x.shape, x.dtype), reap_carry_flat_avals)
    # NOTE(review): the new operand layout is consts + carry + dummy reap carry
    # + xs. However, `linear` is split at len(consts), the count of the NEW
    # jaxpr's consts, rather than after the original consts and carry. Confirm
    # that the inserted False flags line up with the dummy carry operands.
    out = lax.scan_p.bind(
        *(consts + carry_vals + dummy_reap_carry_vals + xs_vals),
        reverse=reverse,
        length=length,
        jaxpr=new_body_jaxpr,
        num_consts=len(consts),
        num_carry=len(carry_vals + dummy_reap_carry_vals),
        linear=(linear[:len(consts)] + (False, ) * len(dummy_reap_carry_vals) +
                linear[len(consts):]),
        unroll=unroll)
    (carry_out,
     carry_reaps), (ys, ys_reaps) = tree_util.tree_unflatten(out_tree, out)
    (carry_out, carry_reaps), (ys, ys_reaps) = tree_util.tree_map(
        trace.pure, ((carry_out, carry_reaps), (ys, ys_reaps)))
    # Re-sow every reaped value at the outer level with its original mode.
    for k, v in {**carry_reaps, **ys_reaps}.items():
        sow(v, tag=settings.tag, mode=metadata[k]['mode'], name=k)
    return carry_out + ys
Example #6
0
def _plant_scan_rule(trace: HarvestTrace, *tracers, length, reverse, jaxpr,
                     num_consts, num_carry, linear, unroll):
    """Injects values into a scan according to their sow mode.

    Planted values for `clobber` sows are closed over by the rewritten body,
    so every iteration sees the same value. Planted values for `append` sows
    are fed through the scan as extra xs operands, so each iteration receives
    its own slice. `strict` sows are rejected.

    Args:
      trace: The active harvest trace. Its dynamic context supplies the plant
        settings and the planted values (`context.plants`).
      *tracers: The scan operands, laid out as consts, then carry, then xs.
      length, reverse, jaxpr, num_consts, num_carry, linear, unroll: The
        parameters of the original `scan` primitive.

    Returns:
      The outputs of the rewritten scan (carry values followed by ys).

    Raises:
      ValueError: if a sow inside the scan body uses `strict` mode.
    """

    # Split the flat operands into consts / carry / xs.
    const_tracers, carry_tracers, xs_tracers = jax_util.split_list(
        tracers, [num_consts, num_carry])
    carry_avals, xs_avals = tree_util.tree_map(lambda x: x.aval,
                                               (carry_tracers, xs_tracers))
    const_vals, carry_vals, xs_vals = tree_util.tree_map(
        lambda x: x.val, (const_tracers, carry_tracers, xs_tracers))
    context = trace_util.get_dynamic_context(trace)
    settings = context.settings
    # For each xs operand, take a single-iteration slice (t[0]) when the tracer
    # supports indexing. Presumably this yields the per-iteration inputs that
    # the body sees (TODO confirm).
    x_tracers = [t[0] if hasattr(t, '_getitem') else t for t in xs_tracers]
    x_avals = [t.aval for t in x_tracers]
    # NOTE(review): this passes tracers, while the reap rule passes `.val`
    # values to _get_harvest_metadata. Confirm both forms are accepted.
    metadata = _get_harvest_metadata(
        jaxpr, settings, *(const_tracers + carry_tracers + x_tracers))

    plants = context.plants
    plant_modes = collections.defaultdict(set)
    # Avals of the `append` sows that actually have a planted value; these
    # become extra xs operands of the rewritten scan.
    plant_xs_avals = {}
    for name, meta in metadata.items():
        mode = meta['mode']
        aval = meta['aval']
        if mode == 'strict':
            raise ValueError(
                f'Cannot use strict mode for \'{name}\' inside `scan`.')
        plant_modes[mode].add(name)
        if mode == 'append' and name in plants:
            plant_xs_avals[name] = aval
    body_fun = jax_core.jaxpr_as_fun(jaxpr)
    clobber_plants = {
        name: value
        for name, value in plants.items() if name in plant_modes['clobber']
    }
    append_plants = {
        name: value
        for name, value in plants.items() if name in plant_modes['append']
    }

    plant_xs_flat_avals, _ = tree_util.tree_flatten(plant_xs_avals)

    # Input tree of the new body: (carry, (xs, per-iteration append plants)).
    plant_xs_in_tree = tree_util.tree_structure(
        (carry_avals, (xs_avals, plant_xs_avals)))

    def new_body(carry, x):
        # `plants` here is this iteration's slice of the append plants; it
        # shadows the outer `plants` dict. Clobber plants are merged in as-is.
        x, plants = x
        all_plants = {**plants, **clobber_plants}
        all_values = const_vals + tree_util.tree_leaves((carry, x))
        out = plant(body_fun,
                    tag=settings.tag,
                    allowlist=settings.allowlist,
                    blocklist=settings.blocklist,
                    exclusive=settings.exclusive)(all_plants, *all_values)
        carry_out, y = jax_util.split_list(out, [num_carry])
        return carry_out, y

    new_body_jaxpr, consts, _ = lcf._initial_style_jaxpr(  # pylint: disable=protected-access
        new_body, plant_xs_in_tree,
        tuple(carry_avals + x_avals + plant_xs_flat_avals))
    plant_vals = tree_util.tree_leaves(append_plants)
    # NOTE(review): `linear` describes the original operands
    # (num_consts + carry + xs), but the consts passed here come from the new
    # jaxpr. Confirm that len(consts) == num_consts, or that `linear` still
    # lines up with the new operand layout.
    out = lcf.scan_p.bind(*(consts + carry_vals + xs_vals + plant_vals),
                          reverse=reverse,
                          length=length,
                          jaxpr=new_body_jaxpr,
                          num_consts=len(consts),
                          num_carry=num_carry,
                          linear=linear + (False, ) * len(plant_vals),
                          unroll=unroll)
    return out
Example #7
0
 def testStringRepresentation(self, tree, correct_string):
   """Checks that the string representation of a tree works."""
   treedef = tree_util.tree_structure(tree)
   self.assertRegex(str(treedef), correct_string)
Example #8
0
 def testTreeDefWithEmptyDictStringRepresentation(self):
   self.assertEqual(str(tree_util.tree_structure({})), "PyTreeDef({})")
Example #9
0
 def testTreeDefWithEmptyDictStringRepresentation(self):
   if jax.lib._xla_extension_version < 35:
     self.skipTest("fixed in future jaxlib")
   self.assertEqual(str(tree_util.tree_structure({})), "PyTreeDef({})")
Example #10
0
def gmres(A,
          b,
          x0=None,
          *,
          tol=1e-5,
          atol=0.0,
          restart=20,
          maxiter=None,
          M=None,
          solve_method='batched'):
    """
  GMRES solves the linear system A x = b for x, given A and b.

  A is specified as a function performing A(vi) -> vf = A @ vi, and in principle
  need not have any particular special properties, such as symmetry. However,
  convergence is often slow for nearly symmetric operators.

  Parameters
  ----------
  A: ndarray or function
      2D array or function that calculates the linear map (matrix-vector
      product) ``Ax`` when called like ``A(x)``. ``A`` must return array(s) with
      the same structure and shape as its argument.
  b : array or tree of arrays
      Right hand side of the linear system representing a single vector. Can be
      stored as an array or Python container of array(s) with any shape.

  Returns
  -------
  x : array or tree of arrays
      The converged solution. Has the same structure as ``b``.
  info : array
      Failure flag: ``-1`` if the norm of the returned solution is NaN, ``0``
      otherwise. Unlike SciPy, this does not yet report the number of
      iterations when convergence is not achieved.

  Raises
  ------
  ValueError
      If ``solve_method`` is not ``'incremental'`` or ``'batched'``, or if
      ``x0`` and ``b`` do not have matching tree structures.

  Other Parameters
  ----------------
  x0 : array, optional
      Starting guess for the solution. Must have the same structure as ``b``.
      If this is unspecified, zeroes are used.
  tol, atol : float, optional
      Tolerances for convergence, ``norm(residual) <= max(tol*norm(b), atol)``.
      We do not implement SciPy's "legacy" behavior, so JAX's tolerance will
      differ from SciPy unless you explicitly pass ``atol`` to SciPy's ``gmres``.
  restart : integer, optional
      Size of the Krylov subspace ("number of iterations") built between
      restarts. GMRES works by approximating the true solution x as its
      projection into a Krylov space of this dimension - this parameter
      therefore bounds the maximum accuracy achievable from any guess
      solution. Larger values increase both number of iterations and iteration
      cost, but may be necessary for convergence. The algorithm terminates
      early if convergence is achieved before the full subspace is built.
      Default is 20, capped at the total number of elements in ``b``.
  maxiter : integer
      Maximum number of times to rebuild the size-``restart`` Krylov space
      starting from the solution found at the last iteration. If GMRES
      halts or is very slow, decreasing this parameter may help.
      Default is ``10 * size``, where ``size`` is the total number of
      elements in ``b`` (the same default SciPy uses).
  M : ndarray or function
      Preconditioner for A.  The preconditioner should approximate the
      inverse of A.  Effective preconditioning dramatically improves the
      rate of convergence, which implies that fewer iterations are needed
      to reach a given error tolerance. If unspecified, an identity
      preconditioner is used.
  solve_method : 'incremental' or 'batched'
      The 'incremental' solve method builds a QR decomposition for the Krylov
      subspace incrementally during the GMRES process using Givens rotations.
      This improves numerical stability and gives a free estimate of the
      residual norm that allows for early termination within a single "restart".
      In contrast, the 'batched' solve method solves the least squares problem
      from scratch at the end of each GMRES iteration. It does not allow for
      early termination, but has much less overhead on GPUs.

  See also
  --------
  scipy.sparse.linalg.gmres
  jax.lax.custom_linear_solve
  """
    # Validate the string option first, so that a typo fails fast before any
    # device transfer or preconditioner application is performed.
    if solve_method == 'incremental':
        gmres_func = _gmres_incremental
    elif solve_method == 'batched':
        gmres_func = _gmres_batched
    else:
        raise ValueError(
            f"invalid solve_method {solve_method}, must be either "
            "'incremental' or 'batched'")

    if x0 is None:
        x0 = tree_map(jnp.zeros_like, b)
    if M is None:
        M = _identity
    A = _normalize_matvec(A)
    M = _normalize_matvec(M)

    b, x0 = device_put((b, x0))
    size = sum(bi.size for bi in tree_leaves(b))

    if maxiter is None:
        maxiter = 10 * size  # copied from scipy
    # A Krylov space can never be larger than the problem dimension.
    restart = min(restart, size)

    if tree_structure(x0) != tree_structure(b):
        raise ValueError('x0 and b must have matching tree structure: '
                         f'{tree_structure(x0)} vs {tree_structure(b)}')

    # Outer (absolute) tolerance on the residual norm.
    b_norm = _norm(b)
    atol = jnp.maximum(tol * b_norm, atol)

    # Matching tolerance for the preconditioned residual, scaled by ||M b||.
    Mb = M(b)
    Mb_norm = _norm(Mb)
    ptol = Mb_norm * jnp.minimum(1.0, atol / b_norm)

    def _solve(A, b):
        return _gmres_solve(A, b, x0, atol, ptol, restart, maxiter, M,
                            gmres_func)

    x = lax.custom_linear_solve(A, b, solve=_solve, transpose_solve=_solve)

    # ``jnp.where``'s x/y arguments are passed positionally: newer JAX
    # releases make them positional-only.
    failed = jnp.isnan(_norm(x))
    info = jnp.where(failed, -1, 0)
    return x, info
Example #11
0
 def in_tree(self):
     """Return the ``PyTreeDef`` of the (positional args, keyword args) pair.

     The structure is computed from ``self.args_info``.
     """
     args_and_kwargs = self.args_info
     return tree_util.tree_structure(args_and_kwargs)
Example #12
0
def tree_fill_like(x, tree):
    """Apply ``tree_fill`` with ``x`` to the structure of ``tree``.

    Only the structure of ``tree`` is used; its leaf values are ignored.
    """
    template = tree_structure(tree)
    return tree_fill(x, template)
Example #13
0
def cg(A, b, x0=None, *, tol=1e-5, atol=0.0, maxiter=None, M=None):
    """Use Conjugate Gradient iteration to solve ``Ax = b``.

  The numerics of JAX's ``cg`` should exactly match SciPy's ``cg`` (up to
  numerical precision), but note that the interface is slightly different: you
  need to supply the linear operator ``A`` as a function instead of a sparse
  matrix or ``LinearOperator``.

  Derivatives of ``cg`` are implemented via implicit differentiation with
  another ``cg`` solve, rather than by differentiating *through* the solver.
  They will be accurate only if both solves converge.

  Parameters
  ----------
  A : function
      Function that calculates the matrix-vector product ``Ax`` when called
      like ``A(x)``. ``A`` must represent a hermitian, positive definite
      matrix, and must return array(s) with the same structure and shape as its
      argument.
  b : array or tree of arrays
      Right hand side of the linear system representing a single vector. Can be
      stored as an array or Python container of array(s) with any shape.

  Returns
  -------
  x : array or tree of arrays
      The converged solution. Has the same structure as ``b``.
  info : None
      Placeholder for convergence information. In the future, JAX will report
      the number of iterations when convergence is not achieved, like SciPy.

  Raises
  ------
  ValueError
      If ``x0`` and ``b`` do not have matching tree structures, or if the
      arrays in ``x0`` and ``b`` do not have matching shapes.

  Other Parameters
  ----------------
  x0 : array, optional
      Starting guess for the solution. Must have the same structure as ``b``.
      If this is unspecified, zeroes are used.
  tol, atol : float, optional
      Tolerances for convergence, ``norm(residual) <= max(tol*norm(b), atol)``.
      We do not implement SciPy's "legacy" behavior, so JAX's tolerance will
      differ from SciPy unless you explicitly pass ``atol`` to SciPy's ``cg``.
  maxiter : integer
      Maximum number of iterations.  Iteration will stop after maxiter
      steps even if the specified tolerance has not been achieved.
      Default is ``10 * size``, where ``size`` is the total number of
      elements in ``b`` (the same default SciPy uses).
  M : function
      Preconditioner for A.  The preconditioner should approximate the
      inverse of A.  Effective preconditioning dramatically improves the
      rate of convergence, which implies that fewer iterations are needed
      to reach a given error tolerance. If unspecified, an identity
      preconditioner is used.

  See also
  --------
  scipy.sparse.linalg.cg
  jax.lax.custom_linear_solve
  """
    if x0 is None:
        x0 = tree_map(jnp.zeros_like, b)

    b, x0 = device_put((b, x0))

    if maxiter is None:
        size = sum(bi.size for bi in tree_leaves(b))
        maxiter = 10 * size  # copied from scipy

    if M is None:
        M = _identity

    # x0 must be a drop-in starting point for b: same tree and same shapes.
    if tree_structure(x0) != tree_structure(b):
        raise ValueError('x0 and b must have matching tree structure: '
                         f'{tree_structure(x0)} vs {tree_structure(b)}')

    if _shapes(x0) != _shapes(b):
        raise ValueError('arrays in x0 and b must have matching shapes: '
                         f'{_shapes(x0)} vs {_shapes(b)}')

    cg_solve = partial(_cg_solve,
                       x0=x0,
                       tol=tol,
                       atol=atol,
                       maxiter=maxiter,
                       M=M)

    # real-valued positive-definite linear operators are symmetric
    def real_valued(x):
        return not issubclass(x.dtype.type, np.complexfloating)

    # Declaring the operator symmetric (only when every leaf of b is real)
    # lets custom_linear_solve reuse the forward solve for transposes.
    symmetric = all(map(real_valued, tree_leaves(b)))
    x = lax.custom_linear_solve(A,
                                b,
                                solve=cg_solve,
                                transpose_solve=cg_solve,
                                symmetric=symmetric)
    info = None  # TODO(shoyer): return the real iteration count here
    return x, info
Example #14
0
 def to_jvp(*primals):
   # Evaluate the batching rule on these primals and check its two outputs
   # with check_vmap_rule_trees. The batched-ness flags are recorded via
   # out_mutually_batched.store before the values are returned.
   out_vals, batched_flags = call_rule(rule, axis_size, mutually_batched,
                                       primals)
   vals_tree = tree_structure(out_vals)
   flags_tree = tree_structure(batched_flags)
   check_vmap_rule_trees(rule, vals_tree, flags_tree)
   out_mutually_batched.store(batched_flags)
   return out_vals
Example #15
0
def _gmres(A,
           b,
           x0=None,
           *,
           tol=1e-5,
           atol=0.0,
           restart=20,
           maxiter=None,
           M=None,
           qr_mode=False):
    """
  GMRES solves the linear system A x = b for x, given A and b.

  A is specified as a function performing A(vi) -> vf = A @ vi, and in principle
  need not have any particular special properties, such as symmetry. However,
  convergence is often slow for nearly symmetric operators.

  Parameters
  ----------
  A: function
     Function that calculates the linear map (matrix-vector product)
     ``Ax`` when called like ``A(x)``. ``A`` must return array(s) with the same
     structure and shape as its argument.
  b : array or tree of arrays
      Right hand side of the linear system representing a single vector. Can be
      stored as an array or Python container of array(s) with any shape.

  Returns
  -------
  x : array or tree of arrays
      The converged solution. Has the same structure as ``b``.
  info : array or int
      Failure flag: ``-1`` if the norm of the returned solution is NaN, ``0``
      otherwise. When ``b`` is all zeros, ``b`` itself and the integer ``0``
      are returned immediately.

  Raises
  ------
  ValueError
      If ``x0`` and ``b`` do not have matching tree structures.

  Other Parameters
  ----------------
  x0 : array, optional
       Starting guess for the solution. Must have the same structure as ``b``.
       If this is unspecified, zeroes are used.
  tol, atol : float, optional
      Tolerances for convergence, ``norm(residual) <= max(tol*norm(b), atol)``.
      We do not implement SciPy's "legacy" behavior, so JAX's tolerance will
      differ from SciPy unless you explicitly pass ``atol`` to SciPy's ``gmres``.
  restart : integer, optional
      Size of the Krylov subspace ("number of iterations") built between
      restarts. GMRES works by approximating the true solution x as its
      projection into a Krylov space of this dimension - this parameter
      therefore bounds the maximum accuracy achievable from any guess
      solution. Larger values increase both number of iterations and iteration
      cost, but may be necessary for convergence. The algorithm terminates
      early if convergence is achieved before the full subspace is built.
      Default is 20, capped at the total number of elements in ``b``.
  maxiter : integer
      Maximum number of times to rebuild the size-``restart`` Krylov space
      starting from the solution found at the last iteration. If GMRES
      halts or is very slow, decreasing this parameter may help.
      Default is ``10 * size``, where ``size`` is the total number of
      elements in ``b`` (the same default SciPy uses).
  M : function
      Preconditioner for A.  The preconditioner should approximate the
      inverse of A.  Effective preconditioning dramatically improves the
      rate of convergence, which implies that fewer iterations are needed
      to reach a given error tolerance. If unspecified, an identity
      preconditioner is used.
  qr_mode : bool
      If True, the algorithm builds an internal Krylov subspace using a QR
      based algorithm, which reduces overhead and improves stability. However,
      it may degrade performance significantly on GPUs or TPUs, in which case
      this flag should be set False.

  See also
  --------
  scipy.sparse.linalg.gmres
  jax.lax.custom_linear_solve
  """

    if x0 is None:
        x0 = tree_map(jnp.zeros_like, b)
    if M is None:
        M = _identity

    b, x0 = device_put((b, x0))
    size = sum(bi.size for bi in tree_leaves(b))

    if maxiter is None:
        maxiter = 10 * size  # copied from scipy
    # A Krylov space can never be larger than the problem dimension.
    restart = min(restart, size)

    if tree_structure(x0) != tree_structure(b):
        raise ValueError('x0 and b must have matching tree structure: '
                         f'{tree_structure(x0)} vs {tree_structure(b)}')

    b_norm = _norm_tree(b)
    # Zero right-hand side: the solution is zero, and b already is zero.
    # NOTE(review): this Python-level comparison presumably needs a concrete
    # (non-traced) b_norm, i.e. it will not work under jit.
    if b_norm == 0:
        return b, 0
    outer_tol = jnp.maximum(tol * b_norm, atol)

    # Matching tolerance for the preconditioned residual, scaled by ||M b||.
    # jnp.minimum (rather than builtin min) keeps this consistent with gmres.
    Mb = M(b)
    Mb_norm = _norm_tree(Mb)
    inner_tol = Mb_norm * jnp.minimum(1.0, outer_tol / b_norm)

    # qr_mode=True selects the QR-based Krylov construction, as documented
    # above. Previously the two branches were swapped.
    gmres_func = _gmres_qr if qr_mode else _gmres_plain

    def _solve(A, b):
        return _gmres_solve(A, b, x0, outer_tol, inner_tol, restart,
                            maxiter, M, gmres_func)

    x = lax.custom_linear_solve(A, b, solve=_solve, transpose_solve=_solve)

    # ``jnp.where``'s x/y arguments are passed positionally: newer JAX
    # releases make them positional-only.
    failed = jnp.isnan(_norm_tree(x))
    info = jnp.where(failed, -1, 0)
    return x, info