def _average(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, weights=None,
             returned=False, keepdims=False):
    """Compute the (optionally weighted) average of ``a``.

    Args:
      a: array-like input; promoted to an inexact dtype.
      axis: axis (or, in the unweighted case, a tuple of axes) to average
        over. ``None`` averages over every element. When ``weights`` is given
        the value is handed to ``_canonicalize_axis`` as-is.
      weights: optional array-like of weights. Must either have the same shape
        as ``a`` or be one-dimensional with length ``a.shape[axis]``.
      returned: when true, return ``(average, sum_of_weights)``.
      keepdims: forwarded to the underlying reductions.

    Returns:
      The average, or the tuple ``(average, sum_of_weights)`` if ``returned``
      is true. The sum of weights is broadcast to the shape of the average.

    Raises:
      ValueError: if ``weights`` has a different shape from ``a`` and ``axis``
        is ``None``, ``weights`` is not 1D, or its length does not match
        ``a.shape[axis]``.
    """
    if weights is None:
        # Treat all weights as 1: the sum of the weights is then simply the
        # number of elements that were averaged together.
        _check_arraylike("average", a)
        a, = _promote_dtypes_inexact(a)
        avg = mean(a, axis=axis, keepdims=keepdims)
        if axis is None:
            weights_sum = lax.full((), core.dimension_as_value(a.size), dtype=avg.dtype)
        else:
            if isinstance(axis, tuple):
                # Several axes are reduced at once, so the element count is the
                # product of the sizes of every reduced axis.
                count = 1
                for d in axis:
                    count = count * core.dimension_as_value(a.shape[d])
            else:
                count = core.dimension_as_value(a.shape[axis])
            weights_sum = lax.full_like(avg, count, dtype=avg.dtype)
    else:
        _check_arraylike("average", a, weights)
        a, weights = _promote_dtypes_inexact(a, weights)

        a_shape = np.shape(a)
        a_ndim = len(a_shape)
        weights_shape = np.shape(weights)
        axis = None if axis is None else _canonicalize_axis(axis, a_ndim)

        if a_shape != weights_shape:
            # Make sure the dimensions work out
            if axis is None:
                raise ValueError("Axis must be specified when shapes of a and "
                                 "weights differ.")
            if len(weights_shape) != 1:
                raise ValueError("1D weights expected when shapes of a and "
                                 "weights differ.")
            if not core.symbolic_equal_dim(weights_shape[0], a_shape[axis]):
                raise ValueError("Length of weights not "
                                 "compatible with specified axis.")

            # Reshape the 1D weights so they line up with ``axis`` of ``a``.
            weights = _broadcast_to(weights, (a_ndim - 1) * (1, ) + weights_shape)
            weights = _moveaxis(weights, -1, axis)

        weights_sum = sum(weights, axis=axis, keepdims=keepdims)
        avg = sum(a * weights, axis=axis, keepdims=keepdims) / weights_sum

    if returned:
        if avg.shape != weights_sum.shape:
            weights_sum = _broadcast_to(weights_sum, avg.shape)
        return avg, weights_sum
    return avg
def _resize(image, shape: core.Shape, method: Union[str, ResizeMethod],
            antialias: bool, precision):
    """Resize ``image`` to ``shape`` using the given interpolation ``method``.

    ``method`` may be a ``ResizeMethod`` or a string accepted by
    ``ResizeMethod.from_string``. Nearest-neighbour resizing is delegated to
    ``_resize_nearest``; every other method looks up its kernel in ``_kernels``
    and is applied through ``_scale_and_translate``.

    Raises:
      ValueError: if ``len(shape)`` differs from ``image.ndim``.
    """
    if len(shape) != image.ndim:
        raise ValueError(
            'shape must have length equal to the number of dimensions of x; '
            f' {shape} vs {image.shape}')

    if isinstance(method, str):
        method = ResizeMethod.from_string(method)
    if method == ResizeMethod.NEAREST:
        return _resize_nearest(image, shape)
    assert isinstance(method, ResizeMethod)
    kernel = _kernels[method]

    image, = _promote_dtypes_inexact(image)

    # Skip dimensions that have scale=1 and translation=0, this is only possible
    # since all of the current resize methods (kernels) are interpolating, so the
    # output = input under an identity warp.
    spatial_dims = tuple(
        dim for dim in range(len(shape))
        if not core.symbolic_equal_dim(image.shape[dim], shape[dim]))

    scale = []
    for dim in spatial_dims:
        if core.symbolic_equal_dim(shape[dim], 0):
            # A zero-sized output dimension gets a neutral scale factor.
            scale.append(1.0)
        else:
            scale.append(core.dimension_as_value(shape[dim]) /
                         core.dimension_as_value(image.shape[dim]))

    translation = [0.] * len(spatial_dims)
    return _scale_and_translate(image, shape, spatial_dims, scale, translation,
                                kernel, antialias, precision)
def matrix_power(a, n):
    """Raise a square matrix (or a batch of them) to the integer power ``n``.

    Negative powers invert the matrix first (via ``inv``). Powers 1, 2 and 3
    are special-cased; larger powers use binary exponentiation.

    Raises:
      TypeError: if ``a`` has fewer than two dimensions, its last two
        dimensions differ, or ``n`` cannot be interpreted as an integer.
    """
    a, = _promote_dtypes_inexact(jnp.asarray(a))

    if a.ndim < 2:
        raise TypeError("{}-dimensional array given. Array must be at least "
                        "two-dimensional".format(a.ndim))
    if a.shape[-2] != a.shape[-1]:
        raise TypeError("Last 2 dimensions of the array must be square")
    try:
        n = operator.index(n)
    except TypeError as err:
        raise TypeError("exponent must be an integer, got {}".format(n)) from err

    if n == 0:
        # The zeroth power is the identity, broadcast over any batch dimensions.
        return jnp.broadcast_to(jnp.eye(a.shape[-2], dtype=a.dtype), a.shape)
    if n < 0:
        a = inv(a)
        n = np.abs(n)

    if n == 1:
        return a
    if n == 2:
        return a @ a
    if n == 3:
        return (a @ a) @ a

    # Binary exponentiation: ``squared`` walks through a, a^2, a^4, ... and is
    # folded into ``product`` whenever the matching bit of n is set.
    squared = None
    product = None
    while n > 0:
        squared = a if squared is None else (squared @ squared)
        n, low_bit = divmod(n, 2)
        if low_bit:
            product = squared if product is None else (product @ squared)
    return product
def polyval(p, x, *, unroll=16):
    """Evaluate the polynomial with coefficients ``p`` at ``x``.

    The coefficients are scanned in order along the leading axis of ``p`` and
    accumulated as ``acc * x + coefficient``. ``unroll`` is forwarded to
    ``lax.scan``.
    """
    _check_arraylike("polyval", p, x)
    p, x = _promote_dtypes_inexact(p, x)

    out_shape = lax.broadcast_shapes(p.shape[1:], x.shape)
    initial = lax.full_like(x, 0, shape=out_shape, dtype=x.dtype)

    def horner_step(acc, coeff):
        # One Horner iteration; nothing is emitted per step.
        return acc * x + coeff, None

    result, _ = lax.scan(horner_step, initial, p, unroll=unroll)
    return result
def svd(a, full_matrices: bool = True, compute_uv: bool = True, hermitian: bool = False):
    """Singular value decomposition of ``a``.

    When ``hermitian`` is false the work is delegated to ``lax_linalg.svd``.
    When it is true the decomposition is built from ``lax_linalg.eigh``: the
    singular values are the absolute values of the second ``eigh`` output,
    sorted in descending order, and the signs are folded into ``vh``.

    Returns:
      ``(u, s, vh)`` if ``compute_uv`` is true in the hermitian path, the
      sorted singular values if it is false, or whatever ``lax_linalg.svd``
      returns in the general path.
    """
    a, = _promote_dtypes_inexact(jnp.asarray(a))

    if not hermitian:
        return lax_linalg.svd(a, full_matrices=full_matrices, compute_uv=compute_uv)

    vecs, vals = lax_linalg.eigh(a)
    s = lax.abs(vals)

    if not compute_uv:
        ascending = lax.sort(s, dimension=-1)
        return lax.rev(ascending, dimensions=[s.ndim - 1])

    sign = lax.sign(vals)
    idxs = lax.broadcasted_iota(np.int64, s.shape, dimension=s.ndim - 1)
    # Sort by magnitude only (num_keys=1), carrying indices and signs along,
    # then reverse everything to obtain descending order.
    s, idxs, sign = lax.sort((s, idxs, sign), dimension=-1, num_keys=1)
    s = lax.rev(s, dimensions=[s.ndim - 1])
    idxs = lax.rev(idxs, dimensions=[s.ndim - 1])
    sign = lax.rev(sign, dimensions=[s.ndim - 1])

    u = jnp.take_along_axis(vecs, idxs[..., None, :], axis=-1)
    vh = _H(u * sign[..., None, :])
    return u, s, vh
def hypot(x1, x2):
    """Return ``sqrt(x1**2 + x2**2)`` elementwise.

    The larger magnitude is factored out, giving
    ``big * sqrt(1 + (small / big)**2)``. Where the larger magnitude is zero
    the result is that zero, and the division uses a denominator of one.
    """
    _check_arraylike("hypot", x1, x2)
    x1, x2 = _promote_dtypes_inexact(x1, x2)

    mag1 = lax.abs(x1)
    mag2 = lax.abs(x2)
    big = maximum(mag1, mag2)
    small = minimum(mag1, mag2)

    # Replace a zero denominator by one so the division is always well-defined.
    safe_big = lax.select(big == 0, lax_internal._ones(big), big)
    ratio = lax.div(small, safe_big)
    scaled = big * lax.sqrt(1 + lax.square(ratio))
    return lax.select(big == 0, big, scaled)
def _solve_triangular(a, b, trans, lower, unit_diagonal):
    """Solve a triangular system with matrix ``a`` and right-hand side ``b``.

    ``trans`` selects the form of ``a`` used: ``0``/``"N"`` (as is),
    ``1``/``"T"`` (transposed) or ``2``/``"C"`` (conjugate-transposed).
    A ``b`` with one dimension fewer than ``a`` is treated as a vector and the
    result is returned with the same rank.

    Raises:
      ValueError: if ``trans`` is not one of the accepted values.
    """
    if trans == 0 or trans == "N":
        transpose_a = conjugate_a = False
    elif trans == 1 or trans == "T":
        transpose_a = True
        conjugate_a = False
    elif trans == 2 or trans == "C":
        transpose_a = conjugate_a = True
    else:
        raise ValueError(f"Invalid 'trans' value {trans}")

    a, b = _promote_dtypes_inexact(jnp.asarray(a), jnp.asarray(b))

    # lax_linalg.triangular_solve only supports matrix 'b's at the moment.
    b_is_vector = jnp.ndim(a) == jnp.ndim(b) + 1
    rhs = b[..., None] if b_is_vector else b

    solution = lax_linalg.triangular_solve(
        a, rhs, left_side=True, lower=lower, transpose_a=transpose_a,
        conjugate_a=conjugate_a, unit_diagonal=unit_diagonal)

    return solution[..., 0] if b_is_vector else solution
def modf(x, out=None):
    """Split ``x`` into fractional and integral parts.

    The integral part is ``floor(x)`` for ``x >= 0`` and ``ceil(x)`` otherwise.

    Returns:
      A tuple ``(fractional, integral)``.

    Raises:
      NotImplementedError: if ``out`` is provided.
    """
    _check_arraylike("modf", x)
    x, = _promote_dtypes_inexact(x)
    if out is not None:
        raise NotImplementedError(
            "The 'out' argument to jnp.modf is not supported.")

    non_negative = lax.ge(x, lax_internal._zero(x))
    integral = _where(non_negative, floor(x), ceil(x))
    fractional = x - integral
    return fractional, integral
def sinc(x):
    """Compute ``sin(pi * x) / (pi * x)`` elementwise.

    At ``x == 0`` the value comes from ``_sinc_maclaurin(0, pi_x)`` instead of
    the quotient, and the quotient itself is evaluated with a substitute
    argument of one so no division by zero takes place.
    """
    _check_arraylike("sinc", x)
    x, = _promote_dtypes_inexact(x)

    at_zero = lax.eq(x, _lax_const(x, 0))
    pi_x = lax.mul(_lax_const(x, np.pi), x)
    safe_pi_x = _where(at_zero, _lax_const(x, 1), pi_x)

    value_at_zero = _sinc_maclaurin(0, pi_x)
    quotient = lax.div(lax.sin(safe_pi_x), safe_pi_x)
    return _where(at_zero, value_at_zero, quotient)
def _roots_no_zeros(p):
    """Return the roots of polynomial ``p`` as eigenvalues of its companion matrix.

    Assumes ``p`` does not have leading zeros and has length > 1.
    """
    p, = _promote_dtypes_inexact(p)

    # Companion matrix: ones on the first sub-diagonal, and the negated,
    # normalised trailing coefficients in the first row.
    sub_diagonal = ones((p.size - 2, ), p.dtype)
    companion = diag(sub_diagonal, -1)
    first_row = -p[1:] / p[0]
    companion = companion.at[0, :].set(first_row)

    return linalg.eigvals(companion)
def matrix_rank(M, tol=None):
    """Return the rank of ``M``, counted as the singular values above ``tol``.

    Inputs with fewer than two dimensions have rank 1 if any entry is nonzero
    and 0 otherwise. When ``tol`` is ``None`` it defaults to
    ``max(singular values) * max(M.shape) * eps`` for the result dtype.

    Raises:
      TypeError: if ``M`` has more than two dimensions.
    """
    M, = _promote_dtypes_inexact(jnp.asarray(M))
    if M.ndim > 2:
        raise TypeError("array should have 2 or fewer dimensions")
    if M.ndim < 2:
        has_nonzero = jnp.any(M != 0)
        return has_nonzero.astype(jnp.int32)

    singular_values = svd(M, full_matrices=False, compute_uv=False)
    if tol is None:
        eps = jnp.finfo(singular_values.dtype).eps
        tol = singular_values.max() * np.max(M.shape) * eps
    return jnp.sum(singular_values > tol)
def polyder(p, m=1):
    """Return the coefficients of the ``m``-th derivative of polynomial ``p``.

    ``m`` must be a concrete integer. ``m == 0`` returns the promoted
    coefficients unchanged.

    Raises:
      ValueError: if ``m`` is negative.
    """
    _check_arraylike("polyder", p)
    m = core.concrete_or_error(operator.index, m, "'m' argument of jnp.polyder")
    p, = _promote_dtypes_inexact(p)
    if m < 0:
        raise ValueError("Order of derivative must be positive")
    if m == 0:
        return p

    # Row j holds, for every kept coefficient, the factor contributed by the
    # j-th differentiation; multiplying down the rows gives the total factor.
    counts = arange(len(p), m, -1)[np.newaxis, :]
    orders = arange(m)[:, np.newaxis]
    coeff = (counts - 1 - orders).prod(0)
    return p[:-m] * coeff
def polymul(a1, a2, *, trim_leading_zeros=False):
    """Multiply two polynomials given by coefficient arrays ``a1`` and ``a2``.

    With ``trim_leading_zeros`` set, and at least one input longer than a
    single coefficient, leading zeros are stripped from both inputs first.
    An empty operand is replaced by ``[0]`` before the full convolution.
    """
    _check_arraylike("polymul", a1, a2)
    a1, a2 = _promote_dtypes_inexact(a1, a2)

    should_trim = trim_leading_zeros and (len(a1) > 1 or len(a2) > 1)
    if should_trim:
        a1 = trim_zeros(a1, trim='f')
        a2 = trim_zeros(a2, trim='f')

    # Guard against empty operands so the convolution is well-defined.
    if len(a1) == 0:
        a1 = asarray([0], dtype=a2.dtype)
    if len(a2) == 0:
        a2 = asarray([0], dtype=a1.dtype)

    return convolve(a1, a2, mode='full')
def eigh(a, UPLO=None, symmetrize_input=True):
    """Eigen-decomposition wrapper around ``lax_linalg.eigh``.

    ``UPLO`` picks the triangle of ``a`` that is used: ``None`` or ``"L"`` for
    the lower one, ``"U"`` for the upper one.

    Returns:
      ``(w, v)``: the second and first outputs of ``lax_linalg.eigh``, in that
      order.

    Raises:
      ValueError: if ``UPLO`` is not ``None``, ``"L"`` or ``"U"``.
    """
    if UPLO is None or UPLO == "L":
        use_lower = True
    elif UPLO == "U":
        use_lower = False
    else:
        raise ValueError(
            "UPLO must be one of None, 'L', or 'U', got {}".format(UPLO))

    a, = _promote_dtypes_inexact(jnp.asarray(a))
    first, second = lax_linalg.eigh(a, lower=use_lower,
                                    symmetrize_input=symmetrize_input)
    # The two outputs are returned in swapped order.
    return second, first
def qr(a, mode="reduced"):
    """QR decomposition of ``a``.

    ``mode`` is one of ``"reduced"``, ``"full"``, ``"r"`` (all reduced
    factorisations) or ``"complete"``. With ``"r"`` only ``r`` is returned,
    otherwise the pair ``(q, r)``.

    Raises:
      ValueError: for any other ``mode``.
    """
    if mode == "complete":
        full_matrices = True
    elif mode in ("reduced", "r", "full"):
        full_matrices = False
    else:
        raise ValueError("Unsupported QR decomposition mode '{}'".format(mode))

    a, = _promote_dtypes_inexact(jnp.asarray(a))
    q, r = lax_linalg.qr(a, full_matrices)
    return r if mode == "r" else (q, r)
def det(a):
    """Determinant of a square matrix or a batch of square matrices.

    2x2 and 3x3 inputs go through ``_det_2x2`` / ``_det_3x3``; every other
    square size is computed as ``sign * exp(logdet)`` from ``slogdet``.

    Raises:
      ValueError: if ``a`` does not have shape ``[..., n, n]``.
    """
    a, = _promote_dtypes_inexact(jnp.asarray(a))
    a_shape = jnp.shape(a)

    if len(a_shape) >= 2:
        rows, cols = a_shape[-2], a_shape[-1]
        if cols == 2 and rows == 2:
            return _det_2x2(a)
        if cols == 3 and rows == 3:
            return _det_3x3(a)
        if cols == rows:
            sign, logdet = slogdet(a)
            return sign * jnp.exp(logdet)

    msg = "Argument to _det() must have shape [..., n, n], got {}"
    raise ValueError(msg.format(a_shape))
def slogdet(a, *, method: Optional[str] = None):
    """Sign and log-determinant of a square matrix (or batch of matrices).

    ``method`` selects the implementation: ``None`` or ``"lu"`` uses
    ``_slogdet_lu``, ``"qr"`` uses ``_slogdet_qr``.

    Raises:
      ValueError: if ``a`` is not of shape ``[..., n, n]`` or ``method`` is
        unknown.
    """
    a, = _promote_dtypes_inexact(jnp.asarray(a))
    a_shape = jnp.shape(a)
    if len(a_shape) < 2 or a_shape[-1] != a_shape[-2]:
        raise ValueError(
            "Argument to slogdet() must have shape [..., n, n], got {}".format(a_shape))

    if method is None or method == "lu":
        return _slogdet_lu(a)
    if method == "qr":
        return _slogdet_qr(a)
    raise ValueError(f"Unknown slogdet method '{method}'. Supported methods "
                     "are 'lu' (`None`), and 'qr'.")
def _lu(a, permute_l):
    """LU decomposition of a single matrix ``a``.

    Returns:
      ``(p @ l, u)`` when ``permute_l`` is true, otherwise ``(p, l, u)``.
      ``p`` is built from the permutation reported by ``lax_linalg.lu``,
      ``l`` has a unit diagonal, and ``u`` is the upper triangle of the packed
      factorisation.
    """
    a, = _promote_dtypes_inexact(jnp.asarray(a))
    packed, _pivots, permutation = lax_linalg.lu(a)
    dtype = lax.dtype(a)
    m, n = jnp.shape(a)

    # p[i, j] is one exactly when permutation[j] == i.
    one_hot = permutation[None, :] == jnp.arange(m)[:, None]
    p = jnp.real(jnp.array(one_hot, dtype=dtype))

    k = min(m, n)
    lower_factor = jnp.tril(packed, -1)[:, :k] + jnp.eye(m, k, dtype=dtype)
    upper_factor = jnp.triu(packed)[:k, :]

    if permute_l:
        return jnp.matmul(p, lower_factor), upper_factor
    return p, lower_factor, upper_factor
def qr(a, mode="reduced"):
    """QR decomposition of ``a``.

    ``mode`` may be ``"raw"`` (returns the transposed ``lax_linalg.geqrf``
    output together with ``taus``), ``"reduced"``/``"full"`` (returns
    ``(q, r)``), ``"r"`` (returns ``r`` only) or ``"complete"`` (returns
    ``(q, r)`` computed with ``full_matrices=True``).

    Raises:
      ValueError: for any other ``mode``.
    """
    a, = _promote_dtypes_inexact(jnp.asarray(a))

    if mode == "raw":
        packed, taus = lax_linalg.geqrf(a)
        return _T(packed), taus

    if mode == "complete":
        full_matrices = True
    elif mode in ("reduced", "r", "full"):
        full_matrices = False
    else:
        raise ValueError(f"Unsupported QR decomposition mode '{mode}'")

    q, r = lax_linalg.qr(a, full_matrices=full_matrices)
    return r if mode == "r" else (q, r)
def vq(obs, code_book, check_finite=True):
    """Assign each observation to its nearest code in ``code_book``.

    Args:
      obs: array-like of observations, rank 1 or 2.
      code_book: array-like of codes with the same rank as ``obs``.
      check_finite: accepted but not used by this implementation.

    Returns:
      ``(code, dist_min)``: the index of the closest code for every
      observation, and the distance to that code.

    Raises:
      ValueError: if the ranks of ``obs`` and ``code_book`` differ, or the
        rank is not 1 or 2.
    """
    _check_arraylike("scipy.cluster.vq.vq", obs, code_book)
    # Promote first so that ``.ndim`` is read from the promoted arrays rather
    # than from the raw arguments, which may not expose that attribute.
    obs, code_book = _promote_dtypes_inexact(obs, code_book)
    if obs.ndim != code_book.ndim:
        raise ValueError("Observation and code_book should have the same rank")
    if obs.ndim == 1:
        # Treat 1D inputs as collections of single-feature rows.
        obs, code_book = obs[..., None], code_book[..., None]
    if obs.ndim != 2:
        raise ValueError("ndim different than 1 or 2 are not supported")

    # explicitly rank promotion
    dist = vmap(lambda ob: norm(ob[None] - code_book, axis=-1))(obs)
    code = argmin(dist, axis=-1)
    dist_min = vmap(operator.getitem)(dist, code)
    return code, dist_min
def _qr(a, mode, pivoting):
    """QR decomposition helper.

    ``mode`` is ``"full"`` or ``"r"`` (both computed with
    ``full_matrices=True``) or ``"economic"``. With ``"r"`` a one-element
    tuple ``(r,)`` is returned, otherwise ``(q, r)``.

    Raises:
      NotImplementedError: if ``pivoting`` is truthy.
      ValueError: for an unsupported ``mode``.
    """
    if pivoting:
        raise NotImplementedError(
            "The pivoting=True case of qr is not implemented.")

    if mode == "economic":
        full_matrices = False
    elif mode in ("full", "r"):
        full_matrices = True
    else:
        raise ValueError(f"Unsupported QR decomposition mode '{mode}'")

    a, = _promote_dtypes_inexact(jnp.asarray(a))
    q, r = lax_linalg.qr(a, full_matrices=full_matrices)
    return (r, ) if mode == "r" else (q, r)
def polydiv(u, v, *, trim_leading_zeros=False):
    """Divide polynomial ``u`` by polynomial ``v``.

    Returns:
      ``(quotient, remainder)``. The remainder keeps the length of ``u``
      unless ``trim_leading_zeros`` is set, in which case leading entries
      below ``sqrt(eps)`` are stripped with ``trim_zeros_tol``.
    """
    _check_arraylike("polydiv", u, v)
    u, v = _promote_dtypes_inexact(u, v)

    deg_u = len(u) - 1
    deg_v = len(v) - 1
    inv_lead = 1. / v[0]
    quotient = zeros(max(deg_u - deg_v + 1, 1), dtype=u.dtype)  # force same dtype
    remainder = u

    # Long division: peel off one quotient coefficient per step and subtract
    # the matching multiple of v from the running remainder.
    for k in range(0, deg_u - deg_v + 1):
        d = inv_lead * remainder[k]
        quotient = quotient.at[k].set(d)
        remainder = remainder.at[k:k + deg_v + 1].add(-d * v)

    if not trim_leading_zeros:
        return quotient, remainder
    # use the square root of finfo(dtype) to approximate the absolute tolerance used in numpy
    tolerance = sqrt(finfo(remainder.dtype).eps)
    return quotient, trim_zeros_tol(remainder, tol=tolerance, trim='f')
def _cho_solve(c, b, lower):
    """Solve a linear system given the triangular factor ``c``.

    Two triangular solves are applied in sequence: the first uses
    ``transpose_a = conjugate_a = not lower``, the second uses
    ``transpose_a = conjugate_a = lower``.
    """
    c, b = _promote_dtypes_inexact(jnp.asarray(c), jnp.asarray(b))
    lax_linalg._check_solve_shapes(c, b)

    solution = b
    for adjoint in (not lower, lower):
        solution = lax_linalg.triangular_solve(
            c, solution, left_side=True, lower=lower,
            transpose_a=adjoint, conjugate_a=adjoint)
    return solution
def polyint(p, m=1, k=None):
    """Return the coefficients of the ``m``-th antiderivative of ``p``.

    ``k`` supplies the integration constants: ``None`` means zero, a scalar
    (or length-1 array) is repeated ``m`` times, otherwise it must have shape
    ``(m,)``. ``m`` must be a concrete integer.

    Raises:
      ValueError: if ``m`` is negative or ``k`` has an incompatible shape.
    """
    m = core.concrete_or_error(operator.index, m, "'m' argument of jnp.polyint")
    if k is None:
        k = 0
    _check_arraylike("polyint", p, k)
    p, k = _promote_dtypes_inexact(p, k)
    if m < 0:
        raise ValueError("Order of integral must be positive (see polyder)")

    k = atleast_1d(k)
    if len(k) == 1:
        k = full((m,), k[0])
    if k.shape != (m,):
        raise ValueError("k must be a scalar or a rank-1 array of length 1 or m.")

    if m == 0:
        return p

    # Each row is one integration step; entries are clamped to at least 1 and
    # multiplied down the rows to give the divisor for every coefficient.
    counts = arange(len(p) + m, 0, -1)[np.newaxis, :]
    orders = arange(m)[:, np.newaxis]
    coeff = maximum(1, counts - 1 - orders).prod(0)
    return true_divide(concatenate((p, k)), coeff)
def roots(p, *, strip_zeros=True):
    """Return the roots of the polynomial with coefficients ``p``.

    With ``strip_zeros`` (the default) the number of leading zeros must be
    concrete and those coefficients are dropped before ``_roots_no_zeros`` is
    called. Otherwise the count is handed to ``_roots_with_zeros``.

    Raises:
      ValueError: if ``p`` is not rank 1 after ``atleast_1d``.
    """
    _check_arraylike("roots", p)
    p = atleast_1d(*_promote_dtypes_inexact(p))
    if p.ndim != 1:
        raise ValueError("Input must be a rank-1 array.")
    if p.size < 2:
        # Fewer than two coefficients: return an empty complex array.
        return array([], dtype=dtypes._to_complex_dtype(p.dtype))

    num_leading_zeros = _where(all(p == 0), len(p), argmin(p == 0))

    if not strip_zeros:
        return _roots_with_zeros(p, num_leading_zeros)

    num_leading_zeros = core.concrete_or_error(
        int, num_leading_zeros,
        "The error occurred in the jnp.roots() function. To use this within a "
        "JIT-compiled context, pass strip_zeros=False, but be aware that leading zeros "
        "will be result in some returned roots being set to NaN.")
    return _roots_no_zeros(p[num_leading_zeros:])
def _eigh(a, b, lower, eigvals_only, eigvals, type):
    """Eigen-decomposition helper supporting only the standard problem.

    Returns:
      ``w`` alone if ``eigvals_only`` is truthy, otherwise ``(w, v)``, where
      ``v, w`` are the outputs of ``lax_linalg.eigh`` in that order.

    Raises:
      NotImplementedError: if ``b`` is not ``None``, ``type`` is not 1, or
        ``eigvals`` is not ``None``.
    """
    if b is not None:
        raise NotImplementedError(
            "Only the b=None case of eigh is implemented")
    if type != 1:
        raise NotImplementedError(
            "Only the type=1 case of eigh is implemented.")
    if eigvals is not None:
        raise NotImplementedError(
            "Only the eigvals=None case of eigh is implemented.")

    a, = _promote_dtypes_inexact(jnp.asarray(a))
    v, w = lax_linalg.eigh(a, lower=lower)
    return w if eigvals_only else (w, v)
def frexp(x):
    """Decompose ``x`` into a mantissa and an integer exponent.

    Returns:
      ``(mantissa, exponent)`` with the exponent converted to ``int32``.
      Where ``x`` is infinite, NaN or zero, the mantissa is ``x`` itself and
      the exponent is zero.

    Raises:
      TypeError: if ``x`` is complex-valued.
    """
    _check_arraylike("frexp", x)
    x, = _promote_dtypes_inexact(x)
    if dtypes.issubdtype(x.dtype, np.complexfloating):
        raise TypeError("frexp does not support complex-valued inputs")

    dtype = dtypes.dtype(x)
    info = dtypes.finfo(dtype)
    exp_mask = (1 << info.nexp) - 1
    exp_bias = ((1 << info.nexp) - 1) >> 1

    bits, exponent = _normalize_float(x)
    # Add the unbiased exponent stored in the bit pattern (plus one).
    exponent += ((bits >> info.nmant) & exp_mask) - exp_bias + 1
    # Clear the exponent field, then write ``exp_bias - 1`` into it.
    bits &= ~(exp_mask << info.nmant)
    bits |= (exp_bias - 1) << info.nmant
    mantissa = lax.bitcast_convert_type(bits, dtype)

    special = isinf(x) | isnan(x) | (x == 0)
    exponent = _where(special, lax_internal._zeros(exponent), exponent)
    return _where(special, x, mantissa), lax.convert_element_type(exponent, np.int32)
def fft(x, fft_type: Union[xla_client.FftType, str], fft_lengths: Sequence[int]):
    """Bind the FFT primitive for ``x``.

    ``fft_type`` is either an ``xla_client.FftType`` or a string understood by
    ``_str_to_fft_type``. RFFT inputs are promoted to an inexact dtype, all
    other types to a complex dtype. With empty ``fft_lengths`` the promoted
    input is returned unchanged.

    Raises:
      TypeError: if ``fft_type`` is neither a string nor an ``FftType``.
      ValueError: if a complex input is given for an RFFT.
    """
    if isinstance(fft_type, str):
        typ = _str_to_fft_type(fft_type)
    elif isinstance(fft_type, xla_client.FftType):
        typ = fft_type
    else:
        raise TypeError(f"Unknown FFT type value '{fft_type}'")

    is_rfft = typ == xla_client.FftType.RFFT
    if is_rfft:
        if np.iscomplexobj(x):
            raise ValueError("only real valued inputs supported for rfft")
        x, = _promote_dtypes_inexact(x)
    else:
        x, = _promote_dtypes_complex(x)

    if len(fft_lengths) == 0:
        # XLA FFT doesn't support 0-rank.
        return x

    lengths = tuple(fft_lengths)
    return fft_p.bind(x, fft_type=typ, fft_lengths=lengths)
def _solve(a, b, sym_pos, lower):
    """Solve ``a @ x = b``.

    Without ``sym_pos`` the call is forwarded to ``np_linalg.solve``.
    Otherwise ``a`` is factored with ``cho_factor`` and the system is solved
    through ``lax.custom_linear_solve``.
    """
    if not sym_pos:
        return np_linalg.solve(a, b)

    a, b = _promote_dtypes_inexact(jnp.asarray(a), jnp.asarray(b))
    lax_linalg._check_solve_shapes(a, b)

    # With custom_linear_solve, we can reuse the same factorization when
    # computing sensitivities. This is considerably faster.
    factors = cho_factor(lax.stop_gradient(a), lower=lower)

    def matvec(x):
        return lax_linalg._matvec_multiply(a, x)

    def factored_solve(_, x):
        return cho_solve(factors, x)

    custom_solve = partial(lax.custom_linear_solve, matvec,
                           solve=factored_solve, symmetric=True)

    if a.ndim == b.ndim + 1:
        # b.shape == [..., m]
        return custom_solve(b)
    # b.shape == [..., m, k]
    return vmap(custom_solve, b.ndim - 1, max(a.ndim, b.ndim) - 1)(b)
def initial_step_size(fun, t0, y0, order, rtol, atol, f0):
    """Choose an initial step size for an ODE solver.

    Algorithm from:
    E. Hairer, S. P. Norsett G. Wanner,
    Solving Ordinary Differential Equations I: Nonstiff Problems, Sec. II.4.

    Args:
      fun: callable invoked as ``fun(y, t)``.
      t0: initial time.
      y0: initial state.
      order: order used in the exponent ``1 / (order + 1)``.
      rtol: relative tolerance.
      atol: absolute tolerance.
      f0: value used as the derivative at ``(y0, t0)``.

    Returns:
      ``min(100 * h0, h1)`` for the two candidate step sizes computed below.
    """
    y0, f0 = _promote_dtypes_inexact(y0, f0)
    dtype = y0.dtype

    scale = atol + jnp.abs(y0) * rtol

    def scaled_norm(v):
        # Norm of ``v`` measured relative to the tolerance scale.
        return jnp.linalg.norm(v / scale.astype(dtype))

    d0 = scaled_norm(y0)
    d1 = scaled_norm(f0)

    # First guess; falls back to a tiny step when either norm is negligible.
    tiny_norms = (d0 < 1e-5) | (d1 < 1e-5)
    h0 = jnp.where(tiny_norms, 1e-6, 0.01 * d0 / d1)

    # Take one explicit step of size h0 and measure how much the derivative changes.
    y1 = y0 + h0.astype(dtype) * f0
    f1 = fun(y1, t0 + h0)
    d2 = scaled_norm(f1 - f0) / h0

    both_flat = (d1 <= 1e-15) & (d2 <= 1e-15)
    h1 = jnp.where(both_flat,
                   jnp.maximum(1e-6, h0 * 1e-3),
                   (0.01 / jnp.max(d1 + d2)) ** (1. / (order + 1.)))

    return jnp.minimum(100. * h0, h1)