def wrapped(*args):
    """Call ``func`` and check its outputs against the expected core dims.

    ``func``, ``expected_output_core_dims``, ``dim_sizes`` and
    ``error_context`` are free variables supplied by the enclosing scope
    (not visible here).

    Returns:
      The unmodified output of ``func(*args)``.

    Raises:
      TypeError: if several outputs are expected but ``func`` did not return a
        tuple, or if the number of outputs differs from the number of expected
        output core-dimension specs.
    """
    out = func(*args)
    outputs = out if isinstance(out, tuple) else [out]
    # Materialize the shapes as a list: ``len`` is taken below, and a lazy
    # ``map`` iterator would not support it.
    out_shapes = [jnp.shape(o) for o in outputs]

    if expected_output_core_dims is None:
        # No signature: every output is treated as having scalar core dims.
        output_core_dims = [()] * len(out_shapes)
    else:
        output_core_dims = expected_output_core_dims

    if len(output_core_dims) > 1 and not isinstance(out, tuple):
        raise TypeError(
            "output must be a tuple when multiple outputs are expected, "
            "got: {!r}\n{}".format(out, error_context))

    if len(out_shapes) != len(output_core_dims):
        raise TypeError(
            'wrong number of output arguments: expected %r, got %r %s'
            % (len(output_core_dims), len(out_shapes), error_context))

    # Work on a copy so the enclosing ``dim_sizes`` is left untouched;
    # ``_update_dim_sizes`` presumably records/validates sizes in place.
    sizes = dict(dim_sizes)
    for shape, core_dims in zip(out_shapes, output_core_dims):
        _update_dim_sizes(sizes, shape, core_dims, error_context,
                          is_input=False)

    return out
def _isolve(_isolve_solve, A, b, x0=None, *, tol=1e-5, atol=0.0,
            maxiter=None, M=None, check_symmetric=False):
    """Shared driver that runs an iterative linear solver on pytrees.

    Args:
      _isolve_solve: solver callable. It is bound with ``x0``, ``tol``,
        ``atol``, ``maxiter`` and ``M`` via ``functools.partial`` and passed as
        both ``solve`` and ``transpose_solve`` to ``lax.custom_linear_solve``.
      A: linear operator, normalized with ``_normalize_matvec``.
      b: right-hand-side pytree.
      x0: initial guess. Defaults to zeros shaped like ``b``. It must match
        ``b`` in tree structure and leaf shapes.
      tol, atol: tolerances forwarded to ``_isolve_solve``.
      maxiter: iteration cap. Defaults to 10 times the total element count of
        ``b``.
      M: preconditioner. Defaults to ``_identity``; normalized with
        ``_normalize_matvec``.
      check_symmetric: when true, ``custom_linear_solve`` is told the operator
        is symmetric if every leaf of ``b`` has a non-complex dtype.

    Returns:
      ``(x, None)``: the solution pytree and a ``None`` placeholder for
      solver info.

    Raises:
      ValueError: if ``x0`` and ``b`` differ in tree structure or leaf shapes.
    """
    if x0 is None:
        x0 = tree_map(jnp.zeros_like, b)

    b, x0 = device_put((b, x0))

    if maxiter is None:
        maxiter = 10 * sum(leaf.size for leaf in tree_leaves(b))  # copied from scipy

    A = _normalize_matvec(A)
    M = _normalize_matvec(_identity if M is None else M)

    x0_tree, b_tree = tree_structure(x0), tree_structure(b)
    if x0_tree != b_tree:
        raise ValueError('x0 and b must have matching tree structure: '
                         f'{x0_tree} vs {b_tree}')

    x0_shapes, b_shapes = _shapes(x0), _shapes(b)
    if x0_shapes != b_shapes:
        raise ValueError('arrays in x0 and b must have matching shapes: '
                         f'{x0_shapes} vs {b_shapes}')

    solve = partial(_isolve_solve, x0=x0, tol=tol, atol=atol,
                    maxiter=maxiter, M=M)

    # real-valued positive-definite linear operators are symmetric
    symmetric = False
    if check_symmetric:
        symmetric = all(
            not issubclass(leaf.dtype.type, np.complexfloating)
            for leaf in tree_leaves(b))

    x = lax.custom_linear_solve(A, b, solve=solve, transpose_solve=solve,
                                symmetric=symmetric)
    return x, None
def wrapped(*args):
    """Apply the wrapped function to ``args`` with NumPy-style vectorization.

    ``pyfunc``, ``excluded`` and ``signature`` are free variables supplied by
    the enclosing scope (not visible here). The steps are:
      1. Convert the arguments to arrays.
      2. Take their core dimensions from ``signature`` (scalar cores when
         there is none).
      3. Broadcast each argument over the loop dimensions it already carries.
      4. Wrap the output-checking function in nested ``vmap`` calls for the
         remaining loop dimensions.
    """
    error_context = ("on vectorized function with excluded={!r} and "
                     "signature={!r}".format(excluded, signature))
    excluded_func, args = _apply_excluded(pyfunc, excluded, args)
    args = tuple(jnp.asarray(a) for a in args)

    if signature is None:
        # Without a signature every argument has an empty (scalar) core.
        input_core_dims = [()] * len(args)
        output_core_dims = None
    else:
        input_core_dims, output_core_dims = _parse_gufunc_signature(signature)

    broadcast_shape, dim_sizes = _parse_input_dimensions(
        args, input_core_dims, error_context)
    checked = _check_output_dims(
        excluded_func, dim_sizes, output_core_dims, error_context)

    # Prefer adding loop dimensions with vmap over broadcasting every argument
    # to the full shape. Batching rules for primitives where only some operands
    # are batched (e.g. lax_linalg.triangular_solve) can then be used.
    operands = []
    batch_depths = []
    for operand, core_dims in zip(args, input_core_dims):
        core_shape = tuple(dim_sizes[name] for name in core_dims)
        full_shape = broadcast_shape + core_shape
        # Broadcast explicitly over the dimensions this argument already has.
        # They may be of size 1, which vmap does not handle.
        # TODO(shoyer): Consider squeezing out size 1 dimensions instead and
        # doing all vectorization with vmap.
        operand_shape = full_shape[-operand.ndim:] if operand.ndim else ()
        operands.append(jnp.broadcast_to(operand, operand_shape))
        batch_depths.append(len(operand_shape) - len(core_shape))

    # Each pass wraps one more vmap. Arguments that still have loop dimensions
    # left are mapped over axis 0; the others are passed unbatched (None).
    vectorized = checked
    while any(batch_depths):
        in_axes = tuple(None if depth <= 0 else 0 for depth in batch_depths)
        batch_depths = [depth - 1 if depth > 0 else 0 for depth in batch_depths]
        vectorized = api.vmap(vectorized, in_axes)
    return vectorized(*operands)
def _norm(x):
    """Return the square root of the summed ``_vdot_real_part`` of each leaf with itself.

    Treats every leaf of the pytree ``x`` as part of one flattened vector.
    Assuming ``_vdot_real_part(a, a)`` is the real part of <a, a>, this is the
    Euclidean norm (TODO confirm against that helper).
    """
    squared_norm = sum(_vdot_real_part(leaf, leaf) for leaf in tree_leaves(x))
    return jnp.sqrt(squared_norm)
def _shapes(pytree):
    """Return the shapes of all leaves of ``pytree`` as a list, in leaf order.

    The result is a concrete list rather than a lazy ``map`` iterator. Callers
    can therefore compare two results with ``==`` / ``!=`` (as the x0-vs-b
    shape check does) and format them readably in error messages.
    """
    return [jnp.shape(leaf) for leaf in tree_leaves(pytree)]