Beispiel #1
0
    def wrapped(*args):
        """Call ``func`` and check its outputs against the expected core dims.

        Args:
          *args: forwarded unchanged to ``func``.

        Returns:
          The result of ``func(*args)``, unmodified.

        Raises:
          TypeError: if more than one output is expected but ``func`` did not
            return a tuple, or if the number of outputs differs from
            ``len(expected_output_core_dims)``. ``_update_dim_sizes`` may also
            raise when an output shape disagrees with the known dim sizes.
        """
        out = func(*args)
        outputs = out if isinstance(out, tuple) else [out]
        # Build a concrete list: ``len(out_shapes)`` is taken below, and a lazy
        # iterator (such as the result of the builtin ``map``) has no length.
        out_shapes = [jnp.shape(o) for o in outputs]

        if expected_output_core_dims is None:
            # No expectation given: every output is checked as having no core
            # dimensions.
            output_core_dims = [()] * len(out_shapes)
        else:
            output_core_dims = expected_output_core_dims
            if len(output_core_dims) > 1 and not isinstance(out, tuple):
                raise TypeError(
                    "output must be a tuple when multiple outputs are expected, "
                    "got: {!r}\n{}".format(out, error_context))
            if len(out_shapes) != len(output_core_dims):
                raise TypeError(
                    'wrong number of output arguments: expected %r, got %r %s'
                    % (len(output_core_dims), len(out_shapes), error_context))

        # Validate against a private copy so the closed-over ``dim_sizes`` is
        # not mutated by ``_update_dim_sizes``.
        sizes = dict(dim_sizes)
        for shape, core_dims in zip(out_shapes, output_core_dims):
            _update_dim_sizes(sizes,
                              shape,
                              core_dims,
                              error_context,
                              is_input=False)

        return out
Beispiel #2
0
def _isolve(_isolve_solve,
            A,
            b,
            x0=None,
            *,
            tol=1e-5,
            atol=0.0,
            maxiter=None,
            M=None,
            check_symmetric=False):
    """Validate inputs and run an iterative linear solve of ``A x = b``.

    The concrete iteration is supplied as ``_isolve_solve``. It is bound with
    ``x0``, ``tol``, ``atol``, ``maxiter`` and ``M`` via ``functools.partial``
    and passed to ``lax.custom_linear_solve`` as both ``solve`` and
    ``transpose_solve``.

    Args:
      _isolve_solve: the solver routine to bind and hand to
        ``lax.custom_linear_solve``.
      A: linear operator, normalized with ``_normalize_matvec``.
      b: right-hand side pytree.
      x0: initial guess pytree. Defaults to ``zeros_like`` of every leaf of
        ``b``. It must match ``b`` in tree structure and leaf shapes.
      tol, atol: tolerances, forwarded to ``_isolve_solve``.
      maxiter: iteration limit. Defaults to 10 times the total number of
        elements in ``b`` (the same default scipy uses).
      M: preconditioner. Defaults to ``_identity``; normalized with
        ``_normalize_matvec``.
      check_symmetric: when true, the operator is declared symmetric to
        ``custom_linear_solve`` iff every leaf of ``b`` is real-valued.

    Returns:
      ``(x, info)``, where ``info`` is always ``None``.

    Raises:
      ValueError: if ``x0`` and ``b`` differ in tree structure or leaf shapes.
    """
    if x0 is None:
        x0 = tree_map(jnp.zeros_like, b)

    b, x0 = device_put((b, x0))

    if maxiter is None:
        size = sum(bi.size for bi in tree_leaves(b))
        maxiter = 10 * size  # copied from scipy

    if M is None:
        M = _identity
    A = _normalize_matvec(A)
    M = _normalize_matvec(M)

    x0_tree, b_tree = tree_structure(x0), tree_structure(b)
    if x0_tree != b_tree:
        raise ValueError('x0 and b must have matching tree structure: '
                         f'{x0_tree} vs {b_tree}')

    # list() guarantees an element-wise comparison (and a readable message)
    # even if ``_shapes`` hands back a lazy iterator: comparing two iterator
    # objects with != compares identity and would always be true.
    x0_shapes, b_shapes = list(_shapes(x0)), list(_shapes(b))
    if x0_shapes != b_shapes:
        raise ValueError('arrays in x0 and b must have matching shapes: '
                         f'{x0_shapes} vs {b_shapes}')

    isolve_solve = partial(_isolve_solve,
                           x0=x0,
                           tol=tol,
                           atol=atol,
                           maxiter=maxiter,
                           M=M)

    # real-valued positive-definite linear operators are symmetric
    def real_valued(x):
        return not issubclass(x.dtype.type, np.complexfloating)

    symmetric = (all(map(real_valued, tree_leaves(b)))
                 if check_symmetric else False)
    x = lax.custom_linear_solve(A,
                                b,
                                solve=isolve_solve,
                                transpose_solve=isolve_solve,
                                symmetric=symmetric)
    info = None
    return x, info
Beispiel #3
0
    def wrapped(*args):
        """Broadcast ``args`` and evaluate ``pyfunc`` over the loop dimensions.

        Arguments are converted with ``jnp.asarray``. Their core dimensions
        come from ``signature``, or are taken to be empty when there is no
        signature. Each argument is explicitly broadcast to its share of the
        broadcast shape. The remaining leading (loop) dimensions are then
        mapped with nested ``api.vmap`` calls around the function built by
        ``_check_output_dims``.

        Returns:
          The result of the vectorized, output-checked function applied to the
          broadcast arguments.
        """
        # Context string handed to the shape-validation helpers, which include
        # it in their error messages.
        error_context = ("on vectorized function with excluded={!r} and "
                         "signature={!r}".format(excluded, signature))
        # Split off the ``excluded`` arguments. The returned function takes
        # only the remaining ``args`` (presumably with the excluded ones bound
        # in; see ``_apply_excluded``).
        excluded_func, args = _apply_excluded(pyfunc, excluded, args)
        args = tuple(map(jnp.asarray, args))

        if signature is not None:
            input_core_dims, output_core_dims = _parse_gufunc_signature(
                signature)
        else:
            # No signature: every input has no core dimensions, and the output
            # checker is left to choose its own default (None).
            input_core_dims = [()] * len(args)
            output_core_dims = None

        broadcast_shape, dim_sizes = _parse_input_dimensions(
            args, input_core_dims, error_context)

        checked_func = _check_output_dims(excluded_func, dim_sizes,
                                          output_core_dims, error_context)

        # Rather than broadcasting all arguments to full broadcast shapes, prefer
        # expanding dimensions using vmap when possible. By pushing broadcasting
        # into vmap, we can make use of more efficient batching rules for
        # primitives where only some arguments are batched (e.g., for
        # lax_linalg.triangular_solve).

        vec_args = []
        vmap_counts = []

        for arg, core_dims in zip(args, input_core_dims):
            # Explicitly broadcast the dimensions already found on each argument,
            # because these dimensions might be of size 1, which vmap doesn't
            # handle.
            # TODO(shoyer): Consider squeezing out size 1 dimensions instead, and
            # doing all vectorization with vmap? This *might* be a little more
            # efficient but would require more careful book-keeping.
            core_shape = tuple(dim_sizes[dim] for dim in core_dims)
            full_shape = broadcast_shape + core_shape
            # Keep only the trailing ``arg.ndim`` entries of the full shape, so
            # the argument's rank is unchanged. A 0-d argument stays 0-d
            # (``full_shape[-0:]`` would be the whole shape, hence the guard).
            vec_shape = full_shape[-arg.ndim:] if arg.ndim else ()

            vec_arg = jnp.broadcast_to(arg, vec_shape)
            vec_args.append(vec_arg)

            # Number of leading loop dimensions this argument carries; each one
            # is peeled off by one level of vmap below.
            vmap_count = len(vec_shape) - len(core_shape)
            vmap_counts.append(vmap_count)

        # Each pass wraps the previous function, so the first vmap created is
        # the innermost one and maps the trailing loop dimension. Arguments
        # with fewer loop dimensions get ``None`` (unmapped) at the outer levels.
        vectorized_func = checked_func
        while any(vmap_counts):
            in_axes = tuple(0 if c > 0 else None for c in vmap_counts)
            vmap_counts = [max(c - 1, 0) for c in vmap_counts]
            vectorized_func = api.vmap(vectorized_func, in_axes)
        return vectorized_func(*vec_args)
Beispiel #4
0
def _norm(x):
    """Return the square root of the sum of ``_vdot_real_part(leaf, leaf)``.

    The sum runs over every leaf of the pytree ``x``. Presumably this is the
    2-norm of ``x`` viewed as one flat vector; confirm against
    ``_vdot_real_part``.
    """
    squared_total = sum(_vdot_real_part(leaf, leaf) for leaf in tree_leaves(x))
    return jnp.sqrt(squared_total)
Beispiel #5
0
def _shapes(pytree):
    """Return the shapes of the leaves of ``pytree`` as a list.

    The result is a concrete list rather than a lazy iterator, so two results
    can be compared with ``==``/``!=`` element-wise and printed readably in
    error messages.
    """
    return [jnp.shape(leaf) for leaf in tree_leaves(pytree)]