Example #1
def _check_output_dims(func, dim_sizes, expected_output_core_dims, error_context):
    def wrapped(*args):
        out = func(*args)
        out_shapes = list(map(jnp.shape, out if isinstance(out, tuple) else [out]))

        if expected_output_core_dims is None:
            output_core_dims = [()] * len(out_shapes)
        else:
            output_core_dims = expected_output_core_dims
            if len(output_core_dims) > 1 and not isinstance(out, tuple):
                raise TypeError(
                    "output must be a tuple when multiple outputs are expected, "
                    "got: {!r}\n{}".format(out, error_context))
            if len(out_shapes) != len(output_core_dims):
                raise TypeError(
                    'wrong number of output arguments: expected %r, got %r %s'
                    % (len(output_core_dims), len(out_shapes), error_context))

        sizes = dict(dim_sizes)
        for shape, core_dims in zip(out_shapes, output_core_dims):
            _update_dim_sizes(sizes,
                              shape,
                              core_dims,
                              error_context,
                              is_input=False)

        return out
    return wrapped
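This wrapper is part of the output-shape validation behind jnp.vectorize. A minimal sketch of how the check surfaces to users (the lambda and signature below are illustrative): a function declared to return a core dimension (n) but actually returning a scalar is expected to raise a shape error when called.

import jax.numpy as jnp

# Declared to return an (n,)-shaped output, but actually returns a scalar,
# so the output-dimension check is expected to raise at call time.
bad = jnp.vectorize(lambda x: jnp.sum(x), signature='(n)->(n)')
try:
    bad(jnp.arange(4.0))
except (TypeError, ValueError) as err:
    print(err)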
Example #2
    def wrapped(*args):
        error_context = ("on vectorized function with excluded={!r} and "
                         "signature={!r}".format(excluded, signature))
        excluded_func, args = _apply_excluded(pyfunc, excluded, args)
        args = tuple(map(jnp.asarray, args))

        if signature is not None:
            input_core_dims, output_core_dims = _parse_gufunc_signature(
                signature)
        else:
            input_core_dims = [()] * len(args)
            output_core_dims = None

        broadcast_shape, dim_sizes = _parse_input_dimensions(
            args, input_core_dims, error_context)

        checked_func = _check_output_dims(excluded_func, dim_sizes,
                                          output_core_dims, error_context)

        # Rather than broadcasting all arguments to full broadcast shapes, prefer
        # expanding dimensions using vmap when possible. By pushing broadcasting
        # into vmap, we can make use of more efficient batching rules for
        # primitives where only some arguments are batched (e.g., for
        # lax_linalg.triangular_solve).

        vec_args = []
        vmap_counts = []

        for arg, core_dims in zip(args, input_core_dims):
            # Explicitly broadcast the dimensions already found on each argument,
            # because these dimensions might be of size 1, which vmap doesn't
            # handle.
            # TODO(shoyer): Consider squeezing out size 1 dimensions instead, and
            # doing all vectorization with vmap? This *might* be a little more
            # efficient but would require more careful book-keeping.
            core_shape = tuple(dim_sizes[dim] for dim in core_dims)
            full_shape = broadcast_shape + core_shape
            vec_shape = full_shape[-arg.ndim:] if arg.ndim else ()

            vec_arg = jnp.broadcast_to(arg, vec_shape)
            vec_args.append(vec_arg)

            vmap_count = len(vec_shape) - len(core_shape)
            vmap_counts.append(vmap_count)

        vectorized_func = checked_func
        while any(vmap_counts):
            in_axes = tuple(0 if c > 0 else None for c in vmap_counts)
            vmap_counts = [max(c - 1, 0) for c in vmap_counts]
            vectorized_func = api.vmap(vectorized_func, in_axes)
        return vectorized_func(*vec_args)
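For reference, a small usage sketch of the public jnp.vectorize API that this wrapper implements, vectorizing a matrix-vector product over a leading batch dimension via vmap (the function and shapes below are illustrative):

import jax.numpy as jnp

def matvec(a, x):
    # Core operation on one (m, n) matrix and one (n,) vector.
    return jnp.dot(a, x)

batched_matvec = jnp.vectorize(matvec, signature='(m,n),(n)->(m)')

a = jnp.ones((4, 2, 3))  # batch of 4 matrices of shape (2, 3)
x = jnp.ones((3,))       # a single vector, broadcast across the batch
print(batched_matvec(a, x).shape)  # (4, 2)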
Example #3
def cg(A, b, x0=None, *, tol=1e-5, atol=0.0, maxiter=None, M=None):
    """Use Conjugate Gradient iteration to solve ``Ax = b``.

  The numerics of JAX's ``cg`` should exactly match SciPy's ``cg`` (up to
  numerical precision), but note that the interface is slightly different: you
  need to supply the linear operator ``A`` as a function instead of a sparse
  matrix or ``LinearOperator``.

  Derivatives of ``cg`` are implemented via implicit differentiation with
  another ``cg`` solve, rather than by differentiating *through* the solver.
  They will be accurate only if both solves converge.

  Parameters
  ----------
  A : function
      Function that calculates the matrix-vector product ``Ax`` when called
      like ``A(x)``. ``A`` must represent a hermitian, positive definite
      matrix, and must return array(s) with the same structure and shape as its
      argument.
  b : array or tree of arrays
      Right hand side of the linear system representing a single vector. Can be
      stored as an array or Python container of array(s) with any shape.

  Returns
  -------
  x : array or tree of arrays
      The converged solution. Has the same structure as ``b``.
  info : None
      Placeholder for convergence information. In the future, JAX will report
      the number of iterations when convergence is not achieved, like SciPy.

  Other Parameters
  ----------------
  x0 : array
      Starting guess for the solution. Must have the same structure as ``b``.
  tol, atol : float, optional
      Tolerances for convergence, ``norm(residual) <= max(tol*norm(b), atol)``.
      We do not implement SciPy's "legacy" behavior, so JAX's tolerance will
      differ from SciPy unless you explicitly pass ``atol`` to SciPy's ``cg``.
  maxiter : integer
      Maximum number of iterations.  Iteration will stop after maxiter
      steps even if the specified tolerance has not been achieved.
  M : function
      Preconditioner for A.  The preconditioner should approximate the
      inverse of A.  Effective preconditioning dramatically improves the
      rate of convergence, which implies that fewer iterations are needed
      to reach a given error tolerance.

  See also
  --------
  scipy.sparse.linalg.cg
  jax.lax.custom_linear_solve
  """
    if x0 is None:
        x0 = tree_map(jnp.zeros_like, b)

    b, x0 = device_put((b, x0))

    if maxiter is None:
        size = sum(bi.size for bi in tree_leaves(b))
        maxiter = 10 * size  # copied from scipy

    if M is None:
        M = _identity

    if tree_structure(x0) != tree_structure(b):
        raise ValueError('x0 and b must have matching tree structure: '
                         f'{tree_structure(x0)} vs {tree_structure(b)}')

    if _shapes(x0) != _shapes(b):
        raise ValueError('arrays in x0 and b must have matching shapes: '
                         f'{_shapes(x0)} vs {_shapes(b)}')

    cg_solve = partial(_cg_solve,
                       x0=x0,
                       tol=tol,
                       atol=atol,
                       maxiter=maxiter,
                       M=M)

    # real-valued positive-definite linear operators are symmetric
    def real_valued(x):
        return not issubclass(x.dtype.type, np.complexfloating)

    symmetric = all(map(real_valued, tree_leaves(b)))
    x = lax.custom_linear_solve(A,
                                b,
                                solve=cg_solve,
                                transpose_solve=cg_solve,
                                symmetric=symmetric)
    info = None  # TODO(shoyer): return the real iteration count here
    return x, info
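A minimal usage sketch for cg (the matrix, right-hand side, and tolerance below are illustrative), passing A as a closure that computes the matrix-vector product:

import jax.numpy as jnp
from jax.scipy.sparse.linalg import cg

# Small symmetric positive definite system.
mat = jnp.array([[4.0, 1.0],
                 [1.0, 3.0]])
b = jnp.array([1.0, 2.0])

x, info = cg(lambda v: mat @ v, b, tol=1e-6)
print(x)     # approximately [0.0909, 0.6364]
print(info)  # currently always None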
Example #4
def _norm(x):
    xs = tree_leaves(x)
    return jnp.sqrt(sum(map(_vdot_real_part, xs, xs)))
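_norm computes a Euclidean norm over all leaves of a pytree. A self-contained sketch of the same idea, with a simplified stand-in for the module's _vdot_real_part helper:

import jax.numpy as jnp
from jax.tree_util import tree_leaves

def _vdot_real_part(x, y):
    # Simplified stand-in: real part of the vector dot product <x, y>.
    return jnp.real(jnp.vdot(x, y))

def _norm(x):
    xs = tree_leaves(x)
    return jnp.sqrt(sum(map(_vdot_real_part, xs, xs)))

print(_norm({'a': jnp.array([3.0]), 'b': jnp.array([4.0])}))  # 5.0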
Example #5
def _shapes(pytree):
    # Materialize as a list so the shapes of two pytrees can be compared directly.
    return list(map(jnp.shape, tree_leaves(pytree)))