def eigh_abstract_eval(operand, lower):
  """Abstract evaluation rule for the symmetric eigendecomposition.

  Args:
    operand: The abstract input. When it is a ``ShapedArray`` it must have
      shape ``[..., n, n]``.
    lower: Unused here; it does not influence the output shapes or dtypes.

  Returns:
    A pair ``(v, w)``. For a ``ShapedArray`` operand, ``v`` has shape
    ``[..., n, n]`` with the operand dtype, and ``w`` has shape ``[..., n]``
    with the dtype given by ``lax_internal._complex_basetype(operand.dtype)``.
    Any other operand is returned unchanged for both outputs.

  Raises:
    ValueError: If a ``ShapedArray`` operand has fewer than two dimensions or
      its last two dimensions differ.
  """
  if not isinstance(operand, ShapedArray):
    # Non-shaped abstract values are passed through for both outputs.
    return operand, operand

  if operand.ndim < 2 or operand.shape[-2] != operand.shape[-1]:
    # The original message concatenated two literals without a separating
    # space ("...[..., n, n],got shape..."); the space is restored here.
    raise ValueError(
        "Argument to symmetric eigendecomposition must have shape [..., n, n], "
        f"got shape {operand.shape}")

  batch_dims = operand.shape[:-2]
  n = operand.shape[-1]
  v = ShapedArray(batch_dims + (n, n), operand.dtype)
  w = ShapedArray(batch_dims + (n,),
                  lax_internal._complex_basetype(operand.dtype))
  return v, w
def _empty_svd(a, *, full_matrices, compute_uv): batch_shape = a.shape[:-2] m, n = a.shape[-2:] s = jnp.empty(batch_shape + (0,), dtype=lax_internal._complex_basetype(a.dtype)) if not compute_uv: return (s,) if full_matrices: size = max(m, n) u = jnp.broadcast_to(jnp.eye(size, dtype=a.dtype), batch_shape + (size, size)) else: u = jnp.empty(batch_shape + (m, n), dtype=a.dtype) v = jnp.empty(batch_shape + (0, 0), dtype=a.dtype) if m < n: u, v = v, u return s, u, v
def svd_abstract_eval(operand, full_matrices, compute_uv):
  """Abstract evaluation rule for the singular value decomposition.

  Args:
    operand: The abstract input; must be a ``ShapedArray`` with at least two
      dimensions, the last two being ``(m, n)``.
    full_matrices: If true, ``u`` is ``[..., m, m]`` and ``vt`` is
      ``[..., n, n]``; otherwise they are ``[..., m, k]`` and ``[..., k, n]``
      with ``k = min(m, n)``.
    compute_uv: If false, only the singular values are returned.

  Returns:
    ``(s,)`` or ``(s, u, vt)`` as ``ShapedArray`` values. ``s`` has shape
    ``[..., min(m, n)]`` and dtype
    ``lax_internal._complex_basetype(operand.dtype)``; ``u`` and ``vt`` keep
    the operand dtype.

  Raises:
    NotImplementedError: If ``operand`` is not a ``ShapedArray``.
    ValueError: If ``operand`` has fewer than two dimensions.
  """
  if not isinstance(operand, ShapedArray):
    raise NotImplementedError
  if operand.ndim < 2:
    raise ValueError("Argument to singular value decomposition must have ndims >= 2")

  lead = operand.shape[:-2]
  rows, cols = operand.shape[-2:]
  rank = min(rows, cols)

  s = ShapedArray(lead + (rank,), lax_internal._complex_basetype(operand.dtype))
  if not compute_uv:
    return (s,)

  u_cols = rows if full_matrices else rank
  vt_rows = cols if full_matrices else rank
  u = ShapedArray(lead + (rows, u_cols), operand.dtype)
  vt = ShapedArray(lead + (vt_rows, cols), operand.dtype)
  return s, u, vt