Ejemplo n.º 1
0
def global_norm(updates: Updates) -> jnp.ndarray:
    """Compute the global L2 norm over every leaf of ``updates``.

    Args:
        updates: a pytree whose leaves are arrays.

    Returns:
        A scalar array holding the square root of the sum of the squared
        entries of all leaves (``0`` for a pytree without leaves). Note that
        this is a single scalar, not a pytree shaped like ``updates``.
    """
    squared_sums = [jnp.sum(jnp.square(leaf)) for leaf in tree_leaves(updates)]
    # The builtin ``sum`` starts from 0, so an empty pytree yields a norm of 0.
    return jnp.sqrt(sum(squared_sums))
Ejemplo n.º 2
0
def _vdot_tree(x, y):
  return sum(tree_leaves(tree_map(partial(
    jnp.vdot, precision=lax.Precision.HIGHEST), x, y)))
Ejemplo n.º 3
0
def _norm(x):
  """Return the square root of the summed ``_vdot_real_part(leaf, leaf)``.

  Every leaf of the pytree ``x`` is paired with itself, the per-leaf values
  are accumulated with the builtin ``sum`` and the square root is returned.
  """
  leaves = tree_leaves(x)
  squared_total = sum(_vdot_real_part(leaf, leaf) for leaf in leaves)
  return jnp.sqrt(squared_total)
Ejemplo n.º 4
0
 def testAllLeavesWithTrees(self, tree):
     leaves = tree_util.tree_leaves(tree)
     self.assertTrue(tree_util.all_leaves(leaves))
     self.assertFalse(tree_util.all_leaves([tree]))
Ejemplo n.º 5
0
def _iterative_classical_gram_schmidt(Q, x, xnorm, max_iterations=2):
  """
  Orthogonalize x against the columns of Q. The process is repeated
  up to `max_iterations` times, or fewer if the condition
  ||r|| < (1/sqrt(2)) ||x|| is met earlier (see below for the meaning
  of r and x).

  Parameters
  ----------
  Q : array or tree of arrays
      A matrix of orthonormal columns.
  x : array or tree of arrays
      A vector. It will be replaced with a new vector q which is orthonormal
      to the columns of Q, such that x in span(col(Q), q).
  xnorm : float
      Norm of x.
  max_iterations : int, optional
      Upper bound on the number of projection passes. Default is 2.

  Returns
  -------
  q : array or tree of arrays
      A unit vector, orthonormal to each column of Q, such that
      x in span(col(Q), q).
  r : array
      Stores the overlaps of x with each vector in Q.

  Notes
  -----
  NOTE(review): the body never rescales q (the normalized vector returned by
  `_safe_normalize` is discarded), so "unit vector" presumably relies on the
  caller normalizing the result -- confirm.
  """
  # "twice is enough"
  # http://slepc.upv.es/documentation/reports/str1.pdf

  # TODO(shoyer): consider switching to only one iteration, like SciPy?

  # This assumes that Q's leaves all have the same dimension in the last
  # axis.
  # r accumulates the projection coefficients; one entry per column of Q.
  r = jnp.zeros((tree_leaves(Q)[0].shape[-1]))
  q = x
  xnorm_scaled = xnorm / jnp.sqrt(2)

  def body_function(carry):
    # One projection pass: remove from q its component along the columns of
    # Q and add the removed coefficients to r.
    k, q, r, qnorm_scaled = carry
    h = _project_on_columns(Q, q)
    Qh = tree_map(lambda X: _dot(X, h), Q)
    q = _sub(q, Qh)
    r = _add(r, h)

    # The inner while_loop below acts as a conditional: its body clears the
    # `not_done` flag, so it runs at most once, and only when
    # k < max_iterations - 1. When it runs it refreshes qnorm_scaled from the
    # updated q; otherwise the incoming qnorm_scaled is passed through.
    def qnorm_cond(carry):
      k, not_done, _, _ = carry
      return jnp.logical_and(not_done, k < (max_iterations - 1))

    def qnorm(carry):
      k, _, q, qnorm_scaled = carry
      # Only the norm is used; the normalized vector is discarded.
      _, qnorm = _safe_normalize(q)
      qnorm_scaled = qnorm / jnp.sqrt(2)
      return (k, False, q, qnorm_scaled)

    init = (k, True, q, qnorm_scaled)
    _, _, q, qnorm_scaled = lax.while_loop(qnorm_cond, qnorm, init)
    return (k + 1, q, r, qnorm_scaled)

  def cond_function(carry):
    # Keep iterating while passes remain and norm(r) < qnorm_scaled.
    # NOTE(review): the docstring words the stopping rule differently
    # (stop once ||r|| < ||x|| / sqrt(2)) -- confirm which is intended.
    k, _, r, qnorm_scaled = carry
    _, rnorm = _safe_normalize(r)
    return jnp.logical_and(k < (max_iterations - 1), rnorm < qnorm_scaled)

  # The first pass is executed unconditionally; further passes are governed
  # by cond_function.
  k, q, r, qnorm_scaled = body_function((0, q, r, xnorm_scaled))
  k, q, r, _ = lax.while_loop(cond_function, body_function,
                              (k, q, r, qnorm_scaled))
  return q, r
Ejemplo n.º 6
0
def _infer_shape_jax(f, *vals, **params):
    """Abstractly evaluate ``f`` on ``vals`` and return its flattened result.

    ``vals`` are converted with ``abstractify`` and ``f`` is wrapped so that
    its output pytree is reduced to a flat list of leaves before being handed
    to ``pe.abstract_eval_fun`` together with ``params``.
    """

    def flat_f(*a, **k):
        # Flatten whatever structure ``f`` returns into its leaves.
        return tree_util.tree_leaves(f(*a, **k))

    abstract_vals = map(abstractify, vals)
    return pe.abstract_eval_fun(flat_f, *abstract_vals, **params)
Ejemplo n.º 7
0
def flatten(structure, expand_composites=False):
  """Flatten ``structure``, adding ``expand_composites`` support for JAX.

  When ``expand_composites`` is truthy and ``JAX_MODE`` is set, the leaves are
  obtained from ``jax.tree_util.tree_leaves``; in every other case the work is
  delegated to ``dm_flatten``.
  """
  if not (expand_composites and JAX_MODE):
    return dm_flatten(structure)
  from jax import tree_util  # pylint: disable=g-import-not-at-top
  return tree_util.tree_leaves(structure)
Ejemplo n.º 8
0
def tree_size(tree: PyTree) -> int:
    """
    Count the scalars stored in a pytree.

    The result is the sum of ``leaf.size`` over all leaves of ``tree``.
    """
    return sum(leaf.size for leaf in tree_leaves(tree))
Ejemplo n.º 9
0
 def donate_argnums(self):
     """Flat tuple of donated argument indices.

     Indices refer to positions in the flattened ``self.args_info``; an index
     is included when the corresponding leaf has a truthy ``donated``.
     """
     flat_info = tree_util.tree_leaves(self.args_info)
     donated = [idx for idx, info in enumerate(flat_info) if info.donated]
     return tuple(donated)
Ejemplo n.º 10
0
    def _ApplyGraphNet(graph):
        """Applies a configured GraphNetwork to a graph.

    This implementation follows Algorithm 1 in https://arxiv.org/abs/1806.01261

    There is one difference. For the nodes update the class aggregates over the
    sender edges and receiver edges separately. This is a bit more general
    than the algorithm described in the paper. The original behaviour can be
    recovered by using only the receiver edge aggregations for the update.

    In addition this implementation supports softmax attention over incoming
    edge features.

    Many popular Graph Neural Networks can be implemented as special cases of
    GraphNets, for more information please see the paper.

    The update, aggregation and attention callables used below
    (`update_edge_fn`, `update_node_fn`, `update_global_fn`,
    `attention_logit_fn`, `attention_normalize_fn`, `attention_reduce_fn` and
    the `aggregate_*_fn` functions) are not defined in this block; presumably
    they are closed over from the enclosing scope -- confirm. Each update
    stage is skipped when its function is falsy.

    Args:
      graph: a `GraphsTuple` containing the graph.

    Returns:
      Updated `GraphsTuple`.

    Raises:
      ValueError: if the node arrays do not all share the same leading
        dimension.
    """
        # pylint: disable=g-long-lambda
        nodes, edges, receivers, senders, globals_, n_node, n_edge = graph
        # Equivalent to jnp.sum(n_node), but jittable
        sum_n_node = tree.tree_leaves(nodes)[0].shape[0]
        sum_n_edge = senders.shape[0]
        if not tree.tree_all(
                tree.tree_map(lambda n: n.shape[0] == sum_n_node, nodes)):
            raise ValueError(
                'All node arrays in nest must contain the same number of nodes.'
            )

        # Gather, for every edge, the features of its sender / receiver node.
        sent_attributes = tree.tree_map(lambda n: n[senders], nodes)
        received_attributes = tree.tree_map(lambda n: n[receivers], nodes)
        # Here we scatter the global features to the corresponding edges,
        # giving us tensors of shape [num_edges, global_feat].
        global_edge_attributes = tree.tree_map(
            lambda g: jnp.repeat(
                g, n_edge, axis=0, total_repeat_length=sum_n_edge), globals_)

        if update_edge_fn:
            edges = update_edge_fn(edges, sent_attributes, received_attributes,
                                   global_edge_attributes)

        if attention_logit_fn:
            # Logits are computed from the (possibly updated) edges and then
            # normalized per receiver node before reducing the edges.
            logits = attention_logit_fn(edges, sent_attributes,
                                        received_attributes,
                                        global_edge_attributes)
            tree_calculate_weights = functools.partial(attention_normalize_fn,
                                                       segment_ids=receivers,
                                                       num_segments=sum_n_node)
            weights = tree.tree_map(tree_calculate_weights, logits)
            edges = attention_reduce_fn(edges, weights)

        if update_node_fn:
            # From here on sent_/received_attributes hold per-node aggregates
            # of edge features (rebinding the per-edge gathers above).
            sent_attributes = tree.tree_map(
                lambda e: aggregate_edges_for_nodes_fn(e, senders, sum_n_node),
                edges)
            received_attributes = tree.tree_map(
                lambda e: aggregate_edges_for_nodes_fn(e, receivers, sum_n_node
                                                       ), edges)
            # Here we scatter the global features to the corresponding nodes,
            # giving us tensors of shape [num_nodes, global_feat].
            global_attributes = tree.tree_map(
                lambda g: jnp.repeat(
                    g, n_node, axis=0, total_repeat_length=sum_n_node),
                globals_)
            nodes = update_node_fn(nodes, sent_attributes, received_attributes,
                                   global_attributes)

        if update_global_fn:
            n_graph = n_node.shape[0]
            graph_idx = jnp.arange(n_graph)
            # To aggregate nodes and edges from each graph to global features,
            # we first construct tensors that map the node to the corresponding graph.
            # For example, if you have `n_node=[1,2]`, we construct the tensor
            # [0, 1, 1]. We then do the same for edges.
            node_gr_idx = jnp.repeat(graph_idx,
                                     n_node,
                                     axis=0,
                                     total_repeat_length=sum_n_node)
            edge_gr_idx = jnp.repeat(graph_idx,
                                     n_edge,
                                     axis=0,
                                     total_repeat_length=sum_n_edge)
            # We use the aggregation function to pool the nodes/edges per graph.
            node_attributes = tree.tree_map(
                lambda n: aggregate_nodes_for_globals_fn(
                    n, node_gr_idx, n_graph), nodes)
            edge_attribtutes = tree.tree_map(
                lambda e: aggregate_edges_for_globals_fn(
                    e, edge_gr_idx, n_graph), edges)
            # These pooled nodes are the inputs to the global update fn.
            globals_ = update_global_fn(node_attributes, edge_attribtutes,
                                        globals_)
        # pylint: enable=g-long-lambda
        # Connectivity (receivers/senders) and the per-graph counts are passed
        # through unchanged; only nodes, edges and globals may have been updated.
        return gn_graph.GraphsTuple(nodes=nodes,
                                    edges=edges,
                                    receivers=receivers,
                                    senders=senders,
                                    globals=globals_,
                                    n_node=n_node,
                                    n_edge=n_edge)
Ejemplo n.º 11
0
def _vdot(x, y):
    f = partial(jnp.vdot, precision=lax.Precision.HIGHEST)
    return sum(tree_leaves(tree_multimap(f, x, y)))
Ejemplo n.º 12
0
 def new_body(carry, x):
     """Run ``body_fun`` on flattened inputs and split off the new carry.

     The constants ``const_vals`` are prepended to the leaves of
     ``(carry, x)``; the first ``num_carry`` outputs form the carry and the
     remainder is returned as the second element.
     """
     flat_inputs = tree_leaves((carry, x))
     results = body_fun(*(const_vals + flat_inputs))
     carry_out, ys = split_list(results, [num_carry])
     return carry_out, ys
Ejemplo n.º 13
0
 def test_var_tree_flatten(self):
     """Flattening a dict keyed by generated vars yields the values 'b', 'd'."""
     make_var = core.gensym()
     scalar_i32 = core.ShapedArray((), np.dtype('int32'))
     # Four variables created in order a, b, c, d.
     a, b, c, d = [make_var(scalar_i32) for _ in range(4)]
     mapping = {c: d, a: b}
     flattened = ''.join(map(str, tree_leaves(mapping)))
     assert 'bd' == flattened
Ejemplo n.º 14
0
def debug_callback_impl(*flat_args, callback: Callable[..., Any],
                        effect: DebugEffect, in_tree: tree_util.PyTreeDef):
    """Rebuild the arguments from ``flat_args``, call ``callback``, flatten.

    ``in_tree`` describes an ``(args, kwargs)`` pair. ``effect`` is accepted
    but not used. The leaves of the callback's return value are returned.
    """
    del effect
    positional, keyword = tree_util.tree_unflatten(in_tree, flat_args)
    return tree_util.tree_leaves(callback(*positional, **keyword))
Ejemplo n.º 15
0
def cg(A, b, x0=None, *, tol=1e-5, atol=0.0, maxiter=None, M=None):
    """Solve ``Ax = b`` with the Conjugate Gradient iteration.

    The numerics of JAX's ``cg`` should exactly match SciPy's ``cg`` (up to
    numerical precision), but the interface differs slightly: the linear
    operator ``A`` is supplied as a function rather than as a sparse matrix
    or ``LinearOperator``.

    Derivatives of ``cg`` are implemented via implicit differentiation with
    another ``cg`` solve, rather than by differentiating *through* the
    solver. They will be accurate only if both solves converge.

    Parameters
    ----------
    A : function
        Function that calculates the matrix-vector product ``Ax`` when
        called like ``A(x)``. ``A`` must represent a hermitian, positive
        definite matrix, and must return array(s) with the same structure
        and shape as its argument.
    b : array or tree of arrays
        Right hand side of the linear system representing a single vector.
        Can be stored as an array or Python container of array(s) with any
        shape.

    Returns
    -------
    x : array or tree of arrays
        The converged solution. Has the same structure as ``b``.
    info : None
        Placeholder for convergence information. In the future, JAX will
        report the number of iterations when convergence is not achieved,
        like SciPy.

    Other Parameters
    ----------------
    x0 : array
        Starting guess for the solution. Must have the same structure as
        ``b``. Zeros are used when omitted.
    tol, atol : float, optional
        Tolerances for convergence,
        ``norm(residual) <= max(tol*norm(b), atol)``. We do not implement
        SciPy's "legacy" behavior, so JAX's tolerance will differ from SciPy
        unless you explicitly pass ``atol`` to SciPy's ``cg``.
    maxiter : integer
        Maximum number of iterations. Iteration will stop after maxiter
        steps even if the specified tolerance has not been achieved.
        Defaults to ten times the total number of entries in ``b``.
    M : function
        Preconditioner for A. The preconditioner should approximate the
        inverse of A. Effective preconditioning dramatically improves the
        rate of convergence, which implies that fewer iterations are needed
        to reach a given error tolerance.

    Raises
    ------
    ValueError
        If ``x0`` and ``b`` differ in tree structure or in leaf shapes.

    See also
    --------
    scipy.sparse.linalg.cg
    jax.lax.custom_linear_solve
    """
    if M is None:
        M = _identity

    if x0 is None:
        x0 = tree_map(jnp.zeros_like, b)

    b, x0 = device_put((b, x0))

    if maxiter is None:
        n_entries = sum(leaf.size for leaf in tree_leaves(b))
        maxiter = 10 * n_entries  # copied from scipy

    if tree_structure(x0) != tree_structure(b):
        structure_msg = ('x0 and b must have matching tree structure: '
                         f'{tree_structure(x0)} vs {tree_structure(b)}')
        raise ValueError(structure_msg)

    if _shapes(x0) != _shapes(b):
        shape_msg = ('arrays in x0 and b must have matching shapes: '
                     f'{_shapes(x0)} vs {_shapes(b)}')
        raise ValueError(shape_msg)

    solver = partial(_cg_solve, x0=x0, tol=tol, atol=atol, maxiter=maxiter,
                     M=M)

    # real-valued positive-definite linear operators are symmetric
    symmetric = all(
        not issubclass(leaf.dtype.type, np.complexfloating)
        for leaf in tree_leaves(b))
    x = lax.custom_linear_solve(A, b, solve=solver, transpose_solve=solver,
                                symmetric=symmetric)
    # TODO(shoyer): return the real iteration count here
    return x, None
Ejemplo n.º 16
0
def wait_until_computed(x):
    """Block until every leaf of the pytree ``x`` has been computed.

    Leaves that expose ``block_until_ready`` are waited on. Leaves without
    that method (for example plain Python scalars) are skipped instead of
    raising ``AttributeError``.
    """
    for leaf in tree_leaves(x):
        block = getattr(leaf, 'block_until_ready', None)
        if block is not None:
            block()
Ejemplo n.º 17
0
 def test_var_tree_flatten(self):
     """Flattening a dict keyed by generated vars yields the values 'b', 'd'."""
     make_var = core.gensym()
     # Four variables created in order a, b, c, d, all of the unit type.
     a, b, c, d = [make_var(core.abstract_unit) for _ in range(4)]
     mapping = {c: d, a: b}
     flattened = ''.join(map(str, tree_leaves(mapping)))
     assert 'bd' == flattened
Ejemplo n.º 18
0
def custom_layer_cau_batch(vals, dims, *, num_consts, in_tree, out_tree,
                           kwargs, **params):
    """Batching rule for layer_cau primitive to handle custom layers.

    Three paths are visible in the body:
      * no input is batched: the primitive is re-bound unchanged;
      * the layer defines `_call_and_update_batched`, not every layer leaf is
        mapped and every remaining argument is mapped: the call is delegated
        to `_layer_cau_batched`;
      * otherwise: `layer_cau_p.impl` is batched generically through
        `primitive.batch_fun`.

    Args:
      vals: flat input values; the first `num_consts` entries are skipped
        before unflattening with `in_tree`.
      dims: batch dimension for each entry of `vals`.
      num_consts: number of leading constant entries in `vals`/`dims`.
      in_tree: tree definition used to rebuild `(layer, *args)`.
      out_tree: forwarded to the primitive unchanged.
      kwargs: dict forwarded to the layer call; `kwargs['has_rng']` is read to
        decide whether the first argument after the layer is an RNG.
      **params: remaining primitive parameters, forwarded unchanged.

    Returns:
      In the first path, whatever `layer_cau_p.bind` returns; in the other
      two, a pair `(flat outputs, output batch dimensions)`.
    """
    if all(dim is batching.not_mapped for dim in dims):
        return layer_cau_p.bind(*vals,
                                num_consts=num_consts,
                                in_tree=in_tree,
                                out_tree=out_tree,
                                kwargs=kwargs,
                                **params)
    # Keep the untouched inputs for the generic fallback at the bottom.
    orig_vals, orig_dims = vals, dims
    vals, dims = vals[num_consts:], dims[num_consts:]
    args = tree_util.tree_unflatten(in_tree, vals)
    # NOTE(review): bare `not_mapped` is presumably the same sentinel as
    # `batching.not_mapped` used above -- confirm against the file's imports.
    dims_ = [not_mapped if dim is None else dim for dim in dims]
    layer, args = args[0], args[1:]
    if hasattr(layer, '_call_and_update_batched'):
        # The leading entries of dims_ belong to the layer's own leaves.
        num_params = len(tree_util.tree_leaves(layer))
        layer_dims, arg_dims = dims_[:num_params], dims_[num_params:]
        if kwargs['has_rng']:
            # `rng` and `rng_dim` are only bound (and only read) when
            # kwargs['has_rng'] is truthy.
            rng, args = args[0], args[1:]
            rng_dim, arg_dims = arg_dims[0], arg_dims[1:]
        mapping_over_layer = all(layer_dim is not not_mapped
                                 for layer_dim in layer_dims)
        mapping_over_args = all(arg_dim is not not_mapped
                                for arg_dim in arg_dims)
        assert mapping_over_layer or mapping_over_args, (layer_dims, arg_dims)
        if not mapping_over_layer and mapping_over_args:
            if kwargs['has_rng']:
                if rng_dim is not not_mapped:
                    # Batched RNG: vmap over the rng only; the layer and the
                    # remaining args are passed whole (in_axes None).
                    arg_dims = tuple(None if dim is not_mapped else dim
                                     for dim in arg_dims)
                    map_fun = jax.vmap(
                        lambda layer, rng, *args: _layer_cau_batched(
                            layer,
                            rng,
                            *args,  # pylint: disable=unnecessary-lambda, g-long-lambda
                            **kwargs),
                        in_axes=(None, rng_dim) + (None, ) * len(arg_dims))
                else:
                    map_fun = lambda layer, *args: _layer_cau_batched(
                        layer,
                        *args,  # pylint: disable=unnecessary-lambda, g-long-lambda
                        **kwargs)
                vals_out, update_out = map_fun(layer, rng, *args)
            else:
                vals_out, update_out = _layer_cau_batched(
                    layer, *args, **kwargs)
            vals_out = tree_util.tree_leaves(vals_out)
            update_out = tree_util.tree_leaves(update_out)
            # This path only supports arguments batched along axis 0.
            assert all(dim == 0 for dim in arg_dims)
            # Assume dimensions out are consistent
            dims_out = (0, ) * len(vals_out)
            # The updated layer leaves are reported as unbatched.
            dims_update = (None, ) * len(update_out)
            assert len(vals_out) == len(dims_out)
            assert len(update_out) == len(dims_update)
            return vals_out + update_out, dims_out + dims_update
    # Generic fallback: batch the primitive's implementation directly using
    # the original (constant-inclusive) values and dims.
    batched, out_dims = primitive.batch_fun(
        lu.wrap_init(
            layer_cau_p.impl,
            dict(params,
                 num_consts=num_consts,
                 in_tree=in_tree,
                 out_tree=out_tree,
                 kwargs=kwargs)), orig_dims)
    return batched.call_wrapped(*orig_vals), out_dims()
Ejemplo n.º 19
0
 def wrapped(*args):
     """Target log-prob of the mapped args minus the summed ILDJ terms."""
     mapped_args = mapping_fn(*args)
     ildj_fn = inverse.ildj(mapping_fn, *args)
     ildjs = ildj_fn(mapped_args)
     log_prob = target_log_prob(mapped_args)
     ildj_total = np.sum(np.array(tree_util.tree_leaves(ildjs)))
     return log_prob - ildj_total
Ejemplo n.º 20
0
def global_norm(items):
    """Return the global L2 norm over all leaves of the pytree ``items``.

    The squared entries of every leaf are summed and the square root of the
    grand total is returned as a scalar array (``0`` for an empty pytree).
    """
    # Accumulate with the builtin ``sum``: ``jnp.sum`` does not accept a
    # Python list of arrays as input.
    squared = (jnp.sum(x**2) for x in tree_leaves(items))
    return jnp.sqrt(sum(squared))
Ejemplo n.º 21
0
 def transpose(res_arg, ct_out):
     """Evaluate ``transpose_jaxpr`` and rebuild the result with ``lin_tree``.

     The leaves of ``(res_arg, ct_out)`` are passed after
     ``transpose_consts``; the flat outputs are unflattened with ``lin_tree``.
     """
     flat_inputs = tree_leaves((res_arg, ct_out))
     transposed_fn = core.jaxpr_as_fun(transpose_jaxpr)
     flat_cts = transposed_fn(*transpose_consts, *flat_inputs)
     return tree_unflatten(lin_tree, flat_cts)
Ejemplo n.º 22
0
def _gmres(A,
           b,
           x0=None,
           *,
           tol=1e-5,
           atol=0.0,
           restart=20,
           maxiter=None,
           M=None,
           qr_mode=False):
    """
    GMRES solves the linear system A x = b for x, given A and b.

    A is specified as a function performing A(vi) -> vf = A @ vi, and in
    principle need not have any particular special properties, such as
    symmetry. However, convergence is often slow for nearly symmetric
    operators.

    Parameters
    ----------
    A: function
        Function that calculates the linear map (matrix-vector product)
        ``Ax`` when called like ``A(x)``. ``A`` must return array(s) with the
        same structure and shape as its argument.
    b : array or tree of arrays
        Right hand side of the linear system representing a single vector.
        Can be stored as an array or Python container of array(s) with any
        shape.

    Returns
    -------
    x : array or tree of arrays
        The converged solution. Has the same structure as ``b``. When the
        norm of ``b`` is zero, ``b`` itself is returned.
    info : int or array
        ``0`` on success, ``-1`` if the norm of the solution is NaN.

    Other Parameters
    ----------------
    x0 : array, optional
        Starting guess for the solution. Must have the same structure as
        ``b``. If this is unspecified, zeroes are used.
    tol, atol : float, optional
        Tolerances for convergence,
        ``norm(residual) <= max(tol*norm(b), atol)``. We do not implement
        SciPy's "legacy" behavior, so JAX's tolerance will differ from SciPy
        unless you explicitly pass ``atol`` to SciPy's ``gmres``.
    restart : integer, optional
        Size of the Krylov subspace ("number of iterations") built between
        restarts. GMRES works by approximating the true solution x as its
        projection into a Krylov space of this dimension - this parameter
        therefore bounds the maximum accuracy achievable from any guess
        solution. Larger values increase both number of iterations and
        iteration cost, but may be necessary for convergence. The algorithm
        terminates early if convergence is achieved before the full subspace
        is built. Default is 20; it is capped at the number of entries in
        ``b``.
    maxiter : integer
        Maximum number of times to rebuild the size-``restart`` Krylov space
        starting from the solution found at the last iteration. If GMRES
        halts or is very slow, decreasing this parameter may help.
        Default is ten times the number of entries in ``b``.
    M : function
        Preconditioner for A. The preconditioner should approximate the
        inverse of A. Effective preconditioning dramatically improves the
        rate of convergence, which implies that fewer iterations are needed
        to reach a given error tolerance.
    qr_mode : bool
        If True, the algorithm builds an internal Krylov subspace using a QR
        based algorithm, which reduces overhead and improves stability.
        However, it may degrade performance significantly on GPUs or TPUs, in
        which case this flag should be set False.

    Raises
    ------
    ValueError
        If ``x0`` and ``b`` do not share the same tree structure.

    See also
    --------
    scipy.sparse.linalg.gmres
    jax.lax.custom_linear_solve
    """

    if x0 is None:
        x0 = tree_map(jnp.zeros_like, b)
    if M is None:
        M = _identity

    b, x0 = device_put((b, x0))
    size = sum(bi.size for bi in tree_leaves(b))

    if maxiter is None:
        maxiter = 10 * size  # copied from scipy
    # A Krylov space cannot usefully exceed the dimension of the problem.
    restart = min(restart, size)

    if tree_structure(x0) != tree_structure(b):
        raise ValueError('x0 and b must have matching tree structure: '
                         f'{tree_structure(x0)} vs {tree_structure(b)}')

    b_norm = _norm_tree(b)
    if b_norm == 0:
        # The zero vector solves A x = 0; also avoids dividing by b_norm.
        return b, 0
    outer_tol = jnp.maximum(tol * b_norm, atol)

    Mb = M(b)
    Mb_norm = _norm_tree(Mb)
    # jnp.minimum (not the builtin min) keeps this an array operation.
    inner_tol = Mb_norm * jnp.minimum(1.0, outer_tol / b_norm)

    # qr_mode=True must select the QR-based solver, as documented; the two
    # implementations were previously swapped.
    gmres_func = _gmres_qr if qr_mode else _gmres_plain

    def _solve(A, b):
        return _gmres_solve(A, b, x0, outer_tol, inner_tol, restart,
                            maxiter, M, gmres_func)

    x = lax.custom_linear_solve(A, b, solve=_solve, transpose_solve=_solve)

    failed = jnp.isnan(_norm_tree(x))
    # x and y are passed positionally: jnp.where treats them as
    # positional-only arguments.
    info = jnp.where(failed, -1, 0)
    return x, info
Ejemplo n.º 23
0
def _shapes(pytree):
  return map(jnp.shape, tree_leaves(pytree))
Ejemplo n.º 24
0
def _vdot_tree(x, y):
    """Sum of ``_vdot`` applied to matching leaves of the pytrees ``x``, ``y``.

    ``x`` and ``y`` must share the same tree structure.
    """
    # ``tree_multimap`` has been removed from JAX; ``tree_map`` accepts
    # several trees and is the drop-in replacement.
    return sum(tree_leaves(tree_map(_vdot, x, y)))
Ejemplo n.º 25
0
def _vdot_real_tree(x, y):
  """Sum of ``_vdot_real_part`` over matching leaves of ``x`` and ``y``."""
  leafwise = tree_map(_vdot_real_part, x, y)
  return sum(tree_leaves(leafwise))
Ejemplo n.º 26
0
 def _cau_jaxpr(self, *args, **kwargs):
     """Evaluate the stored jaxpr on flattened ``args`` and rebuild the output.

     ``kwargs`` are forwarded to ``eval_jaxpr_with_kwargs``; the flat result
     is unflattened with ``self._out_tree``.
     """
     leaves = tree_util.tree_leaves(args)
     flat_result = eval_jaxpr_with_kwargs(
         self._jaxpr.jaxpr, self._jaxpr.literals, *leaves, **kwargs)
     return tree_util.tree_unflatten(self._out_tree, flat_result)
Ejemplo n.º 27
0
def gmres(A, b, x0=None, *, tol=1e-5, atol=0.0, restart=20, maxiter=None,
          M=None, solve_method='batched'):
  """
  GMRES solves the linear system A x = b for x, given A and b.

  A is specified as a function performing A(vi) -> vf = A @ vi, and in principle
  need not have any particular special properties, such as symmetry. However,
  convergence is often slow for nearly symmetric operators.

  Parameters
  ----------
  A: ndarray or function
      2D array or function that calculates the linear map (matrix-vector
      product) ``Ax`` when called like ``A(x)``. ``A`` must return array(s) with
      the same structure and shape as its argument.
  b : array or tree of arrays
      Right hand side of the linear system representing a single vector. Can be
      stored as an array or Python container of array(s) with any shape.

  Returns
  -------
  x : array or tree of arrays
      The converged solution. Has the same structure as ``b``.
  info : array
      ``0`` on success, ``-1`` if the norm of the solution is NaN.

  Other Parameters
  ----------------
  x0 : array or tree of arrays, optional
      Starting guess for the solution. Must have the same structure as ``b``.
      If this is unspecified, zeroes are used.
  tol, atol : float, optional
      Tolerances for convergence, ``norm(residual) <= max(tol*norm(b), atol)``.
      We do not implement SciPy's "legacy" behavior, so JAX's tolerance will
      differ from SciPy unless you explicitly pass ``atol`` to SciPy's ``gmres``.
  restart : integer, optional
      Size of the Krylov subspace ("number of iterations") built between
      restarts. GMRES works by approximating the true solution x as its
      projection into a Krylov space of this dimension - this parameter
      therefore bounds the maximum accuracy achievable from any guess
      solution. Larger values increase both number of iterations and iteration
      cost, but may be necessary for convergence. The algorithm terminates
      early if convergence is achieved before the full subspace is built.
      Default is 20; it is capped at the number of entries in ``b``.
  maxiter : integer
      Maximum number of times to rebuild the size-``restart`` Krylov space
      starting from the solution found at the last iteration. If GMRES
      halts or is very slow, decreasing this parameter may help.
      Default is ten times the number of entries in ``b``.
  M : ndarray or function
      Preconditioner for A.  The preconditioner should approximate the
      inverse of A.  Effective preconditioning dramatically improves the
      rate of convergence, which implies that fewer iterations are needed
      to reach a given error tolerance.
  solve_method : 'incremental' or 'batched'
      The 'incremental' solve method builds a QR decomposition for the Krylov
      subspace incrementally during the GMRES process using Givens rotations.
      This improves numerical stability and gives a free estimate of the
      residual norm that allows for early termination within a single "restart".
      In contrast, the 'batched' solve method solves the least squares problem
      from scratch at the end of each GMRES iteration. It does not allow for
      early termination, but has much less overhead on GPUs.

  Raises
  ------
  ValueError
      If ``x0`` and ``b`` do not share the same tree structure, or if
      ``solve_method`` is neither ``'incremental'`` nor ``'batched'``.

  See also
  --------
  scipy.sparse.linalg.gmres
  jax.lax.custom_linear_solve
  """

  if x0 is None:
    x0 = tree_map(jnp.zeros_like, b)
  if M is None:
    M = _identity
  A = _normalize_matvec(A)
  M = _normalize_matvec(M)

  b, x0 = device_put((b, x0))
  size = sum(bi.size for bi in tree_leaves(b))

  if maxiter is None:
    maxiter = 10 * size  # copied from scipy
  # A Krylov space cannot usefully exceed the dimension of the problem.
  restart = min(restart, size)

  if tree_structure(x0) != tree_structure(b):
    raise ValueError(
        'x0 and b must have matching tree structure: '
        f'{tree_structure(x0)} vs {tree_structure(b)}')

  b_norm = _norm(b)
  # From here on ``atol`` is the effective absolute tolerance.
  atol = jnp.maximum(tol * b_norm, atol)

  Mb = M(b)
  Mb_norm = _norm(Mb)
  # Tolerance expressed relative to the preconditioned right hand side.
  ptol = Mb_norm * jnp.minimum(1.0, atol / b_norm)

  if solve_method == 'incremental':
    gmres_func = _gmres_incremental
  elif solve_method == 'batched':
    gmres_func = _gmres_batched
  else:
    raise ValueError(f"invalid solve_method {solve_method}, must be either "
                     "'incremental' or 'batched'")

  def _solve(A, b):
    return _gmres_solve(A, b, x0, atol, ptol, restart, maxiter, M, gmres_func)
  x = lax.custom_linear_solve(A, b, solve=_solve, transpose_solve=_solve)

  failed = jnp.isnan(_norm(x))
  # x and y are passed positionally: jnp.where treats them as positional-only
  # arguments.
  info = jnp.where(failed, -1, 0)
  return x, info
Ejemplo n.º 28
0
    def handle_call_primitive(self, call_primitive, f, tracers, params,
                              is_map):
        """Handler for call_primitives, like jit or layer_call.

    When an UnzipTracer hits a call primitive, there is either a variable
    inside of the call primitive, in which case the input
    function needs to be unzipped into two, or there are no variables
    in the function, so the call_primitive is recorded in the trace as-is.

    We use `unzip_eval_wrapper`, which returns whether or not an unzip
    was successful or not. If it was successful, we record two new
    Jaxprs into the trace (one for init, one for apply). Otherwise, we
    just record the Jaxpr corresponding to the function call.

    Args:
      call_primitive: a call primitive like xla_call
      f: a jax.linear_util wrapped function to be called
      tracers: inputs to the function
      params: parameters of the primitives
      is_map: whether or not the primitive is a map primitive (e.g. xla_pmap)

    Returns:
      A list of output tracers

    Raises:
      NotImplementedError: if `call_primitive` has an entry in
        `pe.call_partial_eval_rules` (and no custom rule handled it first).
    """
        # Fall back to the wrapped function's name when params has no 'name'.
        name = params.get('name', f.__name__)
        settings = trace_util.get_dynamic_context(self).settings
        tracers = safe_map(self.instantiate_const_abstracted, tracers)
        # Custom rules take precedence over everything below.
        if call_primitive in current_custom_rules():
            return current_custom_rules()[call_primitive](self, f, *tracers,
                                                          **params)
        if call_primitive in pe.call_partial_eval_rules:
            raise NotImplementedError
        in_pvs, in_consts = jax_util.unzip2(t.pval for t in tracers)
        if is_map:
            # For map primitives the known avals are passed through
            # `mapped_aval` with the primitive's axis size.
            pvs = [
                None if pv is None else mapped_aval(params['axis_size'], pv)
                for pv in in_pvs
            ]
        else:
            pvs = in_pvs
        keys = tuple(t.is_key() for t in tracers)
        # Second field records whether this primitive is in `block_registry`.
        new_settings = UnzipSettings(settings.tag, call_primitive
                                     in block_registry)
        fun, aux = unzip_eval(f, self, keys, tuple(pvs), new_settings)
        out_flat = call_primitive.bind(fun, *in_consts, **params)
        # `aux` is only queried after the bind above has run `fun`.
        success, results = aux()
        if not success:
            # No unzip happened: record the call as a single equation.
            out_pvs, out_keys, jaxpr, env = results
            out_pv_consts, consts = jax_util.split_list(
                out_flat, [len(out_pvs)])
            out_tracers = self._bound_output_tracers(call_primitive, params,
                                                     jaxpr, consts, env,
                                                     tracers, out_pvs,
                                                     out_pv_consts, out_keys,
                                                     name, is_map)
            return out_tracers
        # Unzip succeeded: `results` carries paired init/apply information.
        init_name = jax_util.wrap_name(name, 'init')
        apply_name = jax_util.wrap_name(name, 'apply')
        init_pvs, num_init_consts, apply_pvs = results[0]
        init_jaxpr, apply_jaxpr = results[1]
        init_env, apply_env = results[2]
        variable_names, variable_tree, apply_keys = results[3]

        key_tracers = [t for t in tracers if t.is_key()]
        abstract_tracers = [t for t in tracers if not t.is_key()]
        # out_flat is laid out as
        # [init pv consts | init consts | apply pv consts | apply consts].
        all_init_consts, all_apply_consts = jax_util.split_list(
            out_flat, [len(init_pvs) + num_init_consts])
        init_pv_consts, init_consts = jax_util.split_list(
            all_init_consts, [len(init_pvs)])
        apply_pv_consts, apply_consts = jax_util.split_list(
            all_apply_consts, [len(apply_pvs)])

        # The init equation consumes only the key tracers; every one of its
        # outputs is flagged True in the keys list.
        variable_tracers = self._bound_output_tracers(call_primitive, params,
                                                      init_jaxpr, init_consts,
                                                      init_env, key_tracers,
                                                      init_pvs, init_pv_consts,
                                                      [True] * len(init_pvs),
                                                      init_name, is_map)

        unflat_variables = tree_util.tree_unflatten(variable_tree,
                                                    variable_tracers)
        if call_primitive is harvest.nest_p:
            # For nest, all variables are sown together as one dict under the
            # nest's scope name, then read back in their original order.
            variable_dict = harvest.sow(dict(
                safe_zip(variable_names, unflat_variables)),
                                        tag=settings.tag,
                                        name=params['scope'],
                                        mode='strict')
            unflat_variables = tuple(variable_dict[name]
                                     for name in variable_names)
        else:
            # Otherwise each variable is sown individually under its own name.
            unflat_variables = [
                harvest.sow(  # pylint: disable=g-complex-comprehension
                    unflat_variable,
                    tag=settings.tag,
                    name=name,
                    mode='strict') for unflat_variable, name in safe_zip(
                        unflat_variables, variable_names)
            ]
        variable_tracers = tree_util.tree_leaves(unflat_variables)

        # The apply equation takes the sown variables followed by the non-key
        # inputs.
        out_tracers = self._bound_output_tracers(
            call_primitive, params, apply_jaxpr, apply_consts, apply_env,
            variable_tracers + abstract_tracers, apply_pvs, apply_pv_consts,
            apply_keys, apply_name, is_map)
        return out_tracers
Ejemplo n.º 29
0
def _scan_harvest_rule(trace: HarvestTrace, *tracers, length, reverse, jaxpr,
                       num_consts, num_carry, linear):
    """Rebuilds a `scan` so values can be planted into and reaped from its body.

    Planted values for 'clobber' sows are closed over by the new body and
    supplied on every step; planted values for 'append' sows are passed as
    additional scanned inputs. Values reaped from the body are returned as
    extra scanned outputs and re-sown under the same name in 'strict' mode.

    Args:
      trace: The active trace; used to look up the dynamic context and to
        lift the scan outputs with `trace.pure`.
      *tracers: Scan operands. Their `.val`s are split into constants,
        initial carry and scanned inputs using `num_consts` / `num_carry`.
      length: Forwarded unchanged to `scan_p.bind`.
      reverse: Forwarded unchanged to `scan_p.bind`.
      jaxpr: The scan body; `jaxpr.jaxpr` is evaluated inside the new body.
      num_consts: Number of leading operands that are constants.
      num_carry: Number of operands (after the constants) that are carry.
      linear: Tuple of per-operand flags for the original scan.

    Returns:
      The carry outputs followed by the scanned outputs, after being tied to
      the re-sown reaps via `prim.tie_in`.

    Raises:
      ValueError: If a sow found in the body for the active tag uses
        'strict' mode.
    """
    context = trace_util.get_dynamic_context(trace)
    settings = context.settings
    consts, init, xs = jax_util.split_list([tracer.val for tracer in tracers],
                                           [num_consts, num_carry])

    # Look at the sows in the body that match the active tag.
    sow_params = _find_sows(jaxpr, settings.tag)
    active_modes = [p['mode'] for p in sow_params]
    if 'strict' in active_modes:
        raise ValueError('Cannot use strict mode in a scan.')
    active_names = [p['name'] for p in sow_params]
    sow_modes = dict(zip(active_names, active_modes))

    def _planted(wanted_mode):
        # Planted values whose sow in the body uses `wanted_mode`.
        return {
            name: context.plants[name]
            for name in active_names
            if name in context.plants and sow_modes[name] == wanted_mode
        }

    carry_plants = _planted('clobber')
    xs_plants = _planted('append')

    def jaxpr_fun(carry, x):
        outputs = jax_core.eval_jaxpr(jaxpr.jaxpr, [], *(consts + carry + x))
        carry_out, y = jax_util.split_list(outputs, [num_carry])
        return carry_out, y

    harvest_body = harvest(jaxpr_fun,
                           tag=settings.tag,
                           allowlist=settings.allowlist,
                           blocklist=settings.blocklist)

    def body(carry, x):
        x_plants, x_vals = x
        # Per-step plants come from the scanned inputs; 'clobber' plants are
        # the same on every step.
        step_plants = {**carry_plants, **x_plants}
        (carry, y), reaps = harvest_body(step_plants, carry, x_vals)
        return carry, (y, reaps)

    def _per_step_aval(x):
        # Abstract value the body sees for one scanned input: the leading
        # axis of `x` is dropped, unless `x` is the abstract unit.
        aval = jax_core.get_aval(x)
        if aval is jax_core.abstract_unit:
            return aval
        return abstract_arrays.ShapedArray(
            masking.padded_shape_as_value(x.shape[1:]), x.dtype)

    x_avals = tuple(
        _per_step_aval(x) for x in tree_util.tree_leaves((xs_plants, xs)))
    init_avals = tuple(
        abstract_arrays.raise_to_shaped(jax_core.get_aval(value))
        for value in init)
    in_flat, in_tree = tree_util.tree_flatten((init, (xs_plants, xs)))
    body_jaxpr, new_consts, out_tree = (
        jax.lax.lax_control_flow._initial_style_jaxpr(  # pylint: disable=protected-access
            body, in_tree, init_avals + x_avals))
    new_values = list(new_consts) + in_flat

    # `in_flat` is laid out as [init..., xs_plants..., xs...]; the new
    # constants and the planted inputs are flagged as non-linear.
    num_xs_plants = len(in_flat) - len(init) - len(xs)
    remaining_linear = linear[num_consts:]
    carry_linear = remaining_linear[:len(init)]
    xs_linear = remaining_linear[len(init):]
    new_linear = ((False, ) * len(new_consts) + carry_linear +
                  (False, ) * num_xs_plants + xs_linear)
    assert len(new_linear) == len(new_values)

    flat_outs = lax.scan_p.bind(*new_values,
                                length=length,
                                reverse=reverse,
                                jaxpr=body_jaxpr,
                                num_consts=len(new_consts),
                                num_carry=num_carry,
                                linear=new_linear)
    flat_outs = safe_map(trace.pure, flat_outs)
    carry, (ys, reaps) = tree_util.tree_unflatten(out_tree, flat_outs)

    out_reaps = {}
    for name, reaped in reaps.items():
        mode = sow_modes.get(name, 'strict')
        if mode == 'append':
            reaped = tree_util.tree_map(np.concatenate, reaped)
        elif mode == 'clobber':
            # Keep only the entry at index -1 of each stacked leaf.
            reaped = tree_util.tree_map(lambda x: x[-1], reaped)
        out_reaps[name] = sow(reaped,
                              tag=settings.tag,
                              name=name,
                              mode='strict')
    (carry, ys) = prim.tie_in(out_reaps, (carry, ys))
    return carry + ys
Ejemplo n.º 30
0
def _plant_scan_rule(trace: HarvestTrace, *tracers, length, reverse, jaxpr,
                     num_consts, num_carry, linear, unroll):
    """Injects values into a scan according to their sow mode.

    Planted values whose sow uses mode 'clobber' are closed over by the
    rebuilt scan body and supplied on every iteration. Planted values whose
    sow uses mode 'append' are passed to the rebuilt scan as extra scanned
    inputs, after the original ones.

    Args:
      trace: The active trace; used to look up the dynamic context.
      *tracers: Scan operands, split into constants, carry and scanned
        inputs using `num_consts` / `num_carry`.
      length: Forwarded unchanged to `scan_p.bind`.
      reverse: Forwarded unchanged to `scan_p.bind`.
      jaxpr: The scan body, converted to a callable with `jaxpr_as_fun`.
      num_consts: Number of leading operands that are constants.
      num_carry: Number of operands (after the constants) that are carry.
      linear: Tuple of per-operand flags for the original scan.
      unroll: Forwarded unchanged to `scan_p.bind`.

    Returns:
      The outputs of the rebuilt `scan_p.bind` call.

    Raises:
      ValueError: If the harvest metadata for the body reports a sow in
        'strict' mode.
    """

    const_tracers, carry_tracers, xs_tracers = jax_util.split_list(
        tracers, [num_consts, num_carry])
    carry_avals, xs_avals = tree_util.tree_map(lambda x: x.aval,
                                               (carry_tracers, xs_tracers))
    const_vals, carry_vals, xs_vals = tree_util.tree_map(
        lambda x: x.val, (const_tracers, carry_tracers, xs_tracers))
    context = trace_util.get_dynamic_context(trace)
    settings = context.settings
    # Index each scanned input at 0 to get a tracer for what the body sees on
    # one step; tracers without a `_getitem` attribute are passed through.
    x_tracers = [t[0] if hasattr(t, '_getitem') else t for t in xs_tracers]
    x_avals = [t.aval for t in x_tracers]
    metadata = _get_harvest_metadata(
        jaxpr, settings, *(const_tracers + carry_tracers + x_tracers))

    plants = context.plants
    # Group sow names by mode, and record the aval of every 'append' sow for
    # which a value has been planted.
    plant_modes = collections.defaultdict(set)
    plant_xs_avals = {}
    for name, meta in metadata.items():
        mode = meta['mode']
        aval = meta['aval']
        if mode == 'strict':
            raise ValueError(
                f'Cannot use strict mode for \'{name}\' inside `scan`.')
        plant_modes[mode].add(name)
        if mode == 'append' and name in plants:
            plant_xs_avals[name] = aval
    body_fun = jax_core.jaxpr_as_fun(jaxpr)
    clobber_plants = {
        name: value
        for name, value in plants.items() if name in plant_modes['clobber']
    }
    append_plants = {
        name: value
        for name, value in plants.items() if name in plant_modes['append']
    }

    plant_xs_flat_avals, _ = tree_util.tree_flatten(plant_xs_avals)

    plant_xs_in_tree = tree_util.tree_structure(
        (carry_avals, (xs_avals, plant_xs_avals)))

    def new_body(carry, x):
        x, plants = x
        # 'clobber' plants take precedence over per-step plants on a name
        # clash because they are merged last.
        all_plants = {**plants, **clobber_plants}
        all_values = const_vals + tree_util.tree_leaves((carry, x))
        out = plant(body_fun,
                    tag=settings.tag,
                    allowlist=settings.allowlist,
                    blocklist=settings.blocklist,
                    exclusive=settings.exclusive)(all_plants, *all_values)
        carry_out, y = jax_util.split_list(out, [num_carry])
        return carry_out, y

    new_body_jaxpr, consts, _ = lcf._initial_style_jaxpr(  # pylint: disable=protected-access
        new_body, plant_xs_in_tree,
        tuple(carry_avals + x_avals + plant_xs_flat_avals))
    plant_vals = tree_util.tree_leaves(append_plants)

    # The rebuilt body has its own constants (`consts`), whose count need not
    # equal the original `num_consts`, so the leading `linear` entries cannot
    # be reused as-is. Emit one non-linear flag per new constant, keep the
    # original flags for carry and scanned inputs, and mark the appended
    # plant inputs as non-linear. This mirrors `_scan_harvest_rule`.
    new_linear = ((False, ) * len(consts) + linear[num_consts:] +
                  (False, ) * len(plant_vals))

    out = lcf.scan_p.bind(*(consts + carry_vals + xs_vals + plant_vals),
                          reverse=reverse,
                          length=length,
                          jaxpr=new_body_jaxpr,
                          num_consts=len(consts),
                          num_carry=num_carry,
                          linear=new_linear,
                          unroll=unroll)
    return out