def funm(A, func, disp=True):
  A = jnp.asarray(A)
  if A.ndim != 2 or A.shape[0] != A.shape[1]:
    raise ValueError('expected square array_like input')

  T, Z = schur(A)
  T, Z = rsf2csf(T, Z)

  F = jnp.diag(func(jnp.diag(T)))
  F = F.astype(T.dtype.char)

  F, minden = _algorithm_11_1_1(F, T)
  F = Z @ F @ Z.conj().T

  if disp:
    return F

  if F.dtype.char.lower() == 'e':
    tol = jnp.finfo(jnp.float16).eps
  elif F.dtype.char.lower() == 'f':
    tol = jnp.finfo(jnp.float32).eps
  else:
    tol = jnp.finfo(jnp.float64).eps

  minden = jnp.where(minden == 0.0, tol, minden)
  err = jnp.where(jnp.any(jnp.isinf(F)), jnp.inf,
                  jnp.minimum(1, jnp.maximum(
                      tol, (tol / minden) * norm(jnp.triu(T, 1), 1))))

  return F, err
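
# Usage sketch (illustrative): assuming this is the implementation behind
# `jax.scipy.linalg.funm`, applying it with `jnp.exp` should reproduce the matrix
# exponential on a small matrix with well-separated eigenvalues. Note the result
# comes back from the complex Schur form, so it may carry a negligible imaginary part.
import jax.numpy as jnp
from jax.scipy.linalg import funm, expm

A = jnp.array([[1.0, 0.5],
               [0.0, 2.0]])
F = funm(A, jnp.exp)                        # matrix exponential via the Schur/Parlett route
print(jnp.allclose(F, expm(A), atol=1e-4))  # expected: True (float32 tolerance)
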
def rsf2csf(T, Z, check_finite=True):
  T = jnp.asarray(T)
  Z = jnp.asarray(Z)

  for ind, X in enumerate([Z, T]):
    if X.ndim != 2 or X.shape[0] != X.shape[1]:
      arg = 'ZT'[ind]
      raise ValueError(f"Input '{arg}' must be square.")

  if T.shape[0] != Z.shape[0]:
    raise ValueError(
      f"Input array shapes must match: Z: {Z.shape} vs. T: {T.shape}")

  T, Z = _promote_dtypes_complex(T, Z)
  eps = jnp.finfo(T.dtype).eps
  N = T.shape[0]

  if N == 1:
    return T, Z

  def _update_T_Z(m, T, Z):
    mu = np_linalg.eigvals(lax.dynamic_slice(T, (m - 1, m - 1), (2, 2))) - T[m, m]
    r = np_linalg.norm(jnp.array([mu[0], T[m, m - 1]])).astype(T.dtype)
    c = mu[0] / r
    s = T[m, m - 1] / r
    G = jnp.array([[c.conj(), s], [-s, c]], dtype=T.dtype)

    # T[m-1:m+1, m-1:] = G @ T[m-1:m+1, m-1:]
    T_rows = lax.dynamic_slice_in_dim(T, m - 1, 2, axis=0)
    col_mask = jnp.arange(N) >= m - 1
    G_dot_T_zeroed_cols = G @ jnp.where(col_mask, T_rows, 0)
    T_rows_new = jnp.where(~col_mask, T_rows, G_dot_T_zeroed_cols)
    T = lax.dynamic_update_slice_in_dim(T, T_rows_new, m - 1, axis=0)

    # T[:m+1, m-1:m+1] = T[:m+1, m-1:m+1] @ G.conj().T
    T_cols = lax.dynamic_slice_in_dim(T, m - 1, 2, axis=1)
    row_mask = jnp.arange(N)[:, jnp.newaxis] < m + 1
    T_zeroed_rows_dot_GH = jnp.where(row_mask, T_cols, 0) @ G.conj().T
    T_cols_new = jnp.where(~row_mask, T_cols, T_zeroed_rows_dot_GH)
    T = lax.dynamic_update_slice_in_dim(T, T_cols_new, m - 1, axis=1)

    # Z[:, m-1:m+1] = Z[:, m-1:m+1] @ G.conj().T
    Z_cols = lax.dynamic_slice_in_dim(Z, m - 1, 2, axis=1)
    Z = lax.dynamic_update_slice_in_dim(Z, Z_cols @ G.conj().T, m - 1, axis=1)
    return T, Z

  def _rsf2scf_iter(i, TZ):
    m = N - i
    T, Z = TZ
    T, Z = lax.cond(
      jnp.abs(T[m, m - 1]) > eps * (jnp.abs(T[m - 1, m - 1]) + jnp.abs(T[m, m])),
      _update_T_Z,
      lambda m, T, Z: (T, Z),
      m, T, Z)
    T = T.at[m, m - 1].set(0.0)
    return T, Z

  return lax.fori_loop(1, N, _rsf2scf_iter, (T, Z))
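
# Usage sketch (illustrative): assuming this backs `jax.scipy.linalg.rsf2csf`,
# it converts a real Schur form with 2x2 blocks into a complex upper-triangular
# Schur form while preserving the factorization A = Z T Z^H. `schur` is assumed
# to be `jax.scipy.linalg.schur` (typically CPU-only).
import jax.numpy as jnp
from jax.scipy.linalg import schur, rsf2csf

A = jnp.array([[0.0, -1.0, 0.0],
               [1.0,  0.0, 0.0],
               [0.0,  0.0, 2.0]])   # has a complex-conjugate eigenvalue pair
T, Z = schur(A)                      # real Schur form: T contains a 2x2 block
Tc, Zc = rsf2csf(T, Z)               # complex Schur form: Tc is upper triangular
print(jnp.allclose(Zc @ Tc @ Zc.conj().T, A, atol=1e-4))  # expected: True
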
def zeta(x, q=None):
  assert q is not None, "Riemann zeta function is not implemented yet."
  # Reference: Johansson, Fredrik.
  # "Rigorous high-precision computation of the Hurwitz zeta function and its derivatives."
  # Numerical Algorithms 69.2 (2015): 253-270.
  # https://arxiv.org/abs/1309.2877 - formula (5)
  # here we keep the same notation as in the reference
  s, a = _promote_args_inexact("zeta", x, q)
  dtype = lax.dtype(a).type
  s_, a_ = jnp.expand_dims(s, -1), jnp.expand_dims(a, -1)
  # precision ~ N, M
  N = M = dtype(8) if lax.dtype(a) == jnp.float32 else dtype(16)
  assert M <= len(_BERNOULLI_COEFS)
  k = jnp.expand_dims(np.arange(N, dtype=N.dtype), tuple(range(a.ndim)))
  S = jnp.sum((a_ + k)**-s_, -1)
  I = lax.div((a + N)**(dtype(1) - s), s - dtype(1))
  T0 = (a + N)**-s
  m = jnp.expand_dims(np.arange(2 * M, dtype=M.dtype), tuple(range(s.ndim)))
  s_over_a = (s_ + m) / (a_ + N)
  T1 = jnp.cumprod(s_over_a, -1)[..., ::2]
  T1 = jnp.clip(T1, a_max=jnp.finfo(dtype).max)
  coefs = np.expand_dims(np.array(_BERNOULLI_COEFS[:T1.shape[-1]], dtype=dtype),
                         tuple(range(a.ndim)))
  T1 = T1 / coefs
  T = T0 * (dtype(0.5) + T1.sum(-1))
  return S + I + T
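
# Usage sketch (illustrative): assuming this backs `jax.scipy.special.zeta`, the
# Hurwitz zeta with q=1 reduces to the Riemann zeta, so zeta(2, 1) should be close
# to pi**2 / 6.
import jax.numpy as jnp
from jax.scipy.special import zeta

print(zeta(2.0, 1.0))   # expected: ~1.6449341
print(jnp.pi**2 / 6)    # 1.6449341
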
def matrix_rank(M, tol=None):
  M = _promote_arg_dtypes(jnp.asarray(M))
  if M.ndim > 2:
    raise TypeError("array should have 2 or fewer dimensions")
  if M.ndim < 2:
    return jnp.any(M != 0).astype(jnp.int32)
  S = svd(M, full_matrices=False, compute_uv=False)
  if tol is None:
    tol = S.max() * np.max(M.shape) * jnp.finfo(S.dtype).eps
  return jnp.sum(S > tol)
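
# Usage sketch (illustrative): assuming this backs `jnp.linalg.matrix_rank`, a
# matrix whose second row is a multiple of the first has rank 1.
import jax.numpy as jnp

a = jnp.array([[1.0, 2.0],
               [2.0, 4.0]])                # second row is 2x the first
print(jnp.linalg.matrix_rank(a))           # expected: 1
print(jnp.linalg.matrix_rank(jnp.eye(3)))  # expected: 3
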
def _lstsq(a, b, rcond, *, numpy_resid=False):
  # TODO: add lstsq to lax_linalg and implement this function via those wrappers.
  # TODO: add custom jvp rule for more robust lstsq differentiation
  a, b = _promote_arg_dtypes(a, b)
  if a.shape[0] != b.shape[0]:
    raise ValueError("Leading dimensions of input arrays must match")
  b_orig_ndim = b.ndim
  if b_orig_ndim == 1:
    b = b[:, None]
  if a.ndim != 2:
    raise TypeError(
      f"{a.ndim}-dimensional array given. Array must be two-dimensional")
  if b.ndim != 2:
    raise TypeError(
      f"{b.ndim}-dimensional array given. Array must be one or two-dimensional")
  m, n = a.shape
  dtype = a.dtype
  if rcond is None:
    rcond = jnp.finfo(dtype).eps * max(n, m)
  else:
    rcond = jnp.where(rcond < 0, jnp.finfo(dtype).eps, rcond)
  u, s, vt = svd(a, full_matrices=False)
  mask = s >= rcond * s[0]
  rank = mask.sum()
  safe_s = jnp.where(mask, s, 1)
  s_inv = jnp.where(mask, 1 / safe_s, 0)[:, jnp.newaxis]
  uTb = jnp.matmul(u.conj().T, b, precision=lax.Precision.HIGHEST)
  x = jnp.matmul(vt.conj().T, s_inv * uTb, precision=lax.Precision.HIGHEST)
  # Numpy returns empty residuals in some cases. To allow compilation, we
  # default to returning full residuals in all cases.
  if numpy_resid and (rank < n or m <= n):
    resid = jnp.asarray([])
  else:
    b_estimate = jnp.matmul(a, x, precision=lax.Precision.HIGHEST)
    resid = norm(b - b_estimate, axis=0)**2
  if b_orig_ndim == 1:
    x = x.ravel()
  return x, resid, rank, s
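
# Usage sketch (illustrative): `_lstsq` appears to be the implementation behind
# `jnp.linalg.lstsq`; fitting a straight line through noiseless points recovers
# the slope and intercept up to float tolerance.
import jax.numpy as jnp

x = jnp.arange(5.0)
A = jnp.vstack([x, jnp.ones_like(x)]).T   # design matrix for y = m*x + c
y = 2.0 * x + 1.0
coeffs, resid, rank, sv = jnp.linalg.lstsq(A, y, rcond=None)
print(coeffs)   # expected: approximately [2., 1.]
print(rank)     # expected: 2
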
def _expn2(n, x):
  # x > 1.
  _c = _constant_like
  BIG = _c(x, 1.44115188075855872e17)
  MACHEP = jnp.finfo(BIG.dtype).eps  # ?
  zero = _c(x, 0.0)
  one = _c(x, 1.0)

  init = dict(
    k=_c(n, 1),
    pkm2=one,
    qkm2=x,
    pkm1=one,
    qkm1=x + n,
    ans=one / (x + n),
    t=_c(x, jnp.inf),
    r=zero,
    x=x,
  )

  def body(d):
    x = d["x"]
    d["k"] += _c(d["k"], 1)
    k = d["k"]
    odd = k % _c(k, 2) == _c(k, 1)
    yk = jnp.where(odd, one, x)
    xk = jnp.where(odd, n + (k - _c(k, 1)) / _c(k, 2), k / _c(k, 2))
    pk = d["pkm1"] * yk + d["pkm2"] * xk
    qk = d["qkm1"] * yk + d["qkm2"] * xk
    nz = qk != zero
    d["r"] = r = jnp.where(nz, pk / qk, d["r"])
    d["t"] = jnp.where(nz, abs((d["ans"] - r) / r), one)
    d["ans"] = jnp.where(nz, r, d["ans"])
    d["pkm2"] = d["pkm1"]
    d["pkm1"] = pk
    d["qkm2"] = d["qkm1"]
    d["qkm1"] = qk
    is_big = abs(pk) > BIG
    for s in "pq":
      for i in "12":
        key = s + "km" + i
        d[key] = jnp.where(is_big, d[key] / BIG, d[key])
    return d

  def cond(d):
    return (d["x"] > _c(d["k"], 0)) & (d["t"] > MACHEP)

  d = lax.while_loop(cond, body, init)
  return d["ans"] * jnp.exp(-x)
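
# Usage sketch (illustrative): `_expn2` appears to be the continued-fraction branch
# (x > 1) of the generalized exponential integral E_n(x). Assuming the public entry
# point is `jax.scipy.special.expn`, E_1(2) should match the reference value ~0.04890.
from jax.scipy.special import expn

print(expn(1, 2.0))   # expected: ~0.0489005
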
def pinv(a, rcond=None):
  # Uses same algorithm as
  # https://github.com/numpy/numpy/blob/v1.17.0/numpy/linalg/linalg.py#L1890-L1979
  a = jnp.conj(a)
  if rcond is None:
    max_rows_cols = max(a.shape[-2:])
    rcond = 10. * max_rows_cols * jnp.finfo(a.dtype).eps
  rcond = jnp.asarray(rcond)
  u, s, vh = svd(a, full_matrices=False)
  # Singular values less than or equal to ``rcond * largest_singular_value``
  # are set to zero.
  cutoff = rcond[..., jnp.newaxis] * jnp.amax(s, axis=-1, keepdims=True,
                                              initial=-jnp.inf)
  s = jnp.where(s > cutoff, s, jnp.inf)
  res = jnp.matmul(_T(vh), jnp.divide(_T(u), s[..., jnp.newaxis]))
  return lax.convert_element_type(res, a.dtype)
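
# Usage sketch (illustrative): assuming this backs `jnp.linalg.pinv`, the
# Moore-Penrose pseudo-inverse of a tall matrix A satisfies A @ pinv(A) @ A == A.
import jax.numpy as jnp

A = jnp.array([[1.0, 2.0],
               [3.0, 4.0],
               [5.0, 6.0]])
A_pinv = jnp.linalg.pinv(A)
print(A_pinv.shape)                                # (2, 3)
print(jnp.allclose(A @ A_pinv @ A, A, atol=1e-4))  # expected: True
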
def polydiv(u, v, *, trim_leading_zeros=False):
  _check_arraylike("polydiv", u, v)
  u, v = _promote_dtypes_inexact(u, v)
  m = len(u) - 1
  n = len(v) - 1
  scale = 1. / v[0]
  q = zeros(max(m - n + 1, 1), dtype=u.dtype)  # force same dtype
  for k in range(0, m - n + 1):
    d = scale * u[k]
    q = q.at[k].set(d)
    u = u.at[k:k + n + 1].add(-d * v)
  if trim_leading_zeros:
    # use the square root of finfo(dtype).eps to approximate the absolute
    # tolerance used in numpy
    return q, trim_zeros_tol(u, tol=sqrt(finfo(u.dtype).eps), trim='f')
  else:
    return q, u
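
# Usage sketch (illustrative): assuming this backs `jnp.polydiv`, dividing
# x**2 + 3x + 2 by x + 1 gives quotient x + 2 with zero remainder.
import jax.numpy as jnp

q, r = jnp.polydiv(jnp.array([1.0, 3.0, 2.0]), jnp.array([1.0, 1.0]))
print(q)   # expected: [1., 2.]
print(r)   # expected: remainder coefficients, all (numerically) zero
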
def _expn1(n, x):
  # exponential integral En
  _c = _constant_like
  x = jnp.array(x)
  MACHEP = jnp.finfo(x.dtype).eps

  zero = _c(x, 0.0)
  one = _c(x, 1.0)
  psi = -jnp.euler_gamma - jnp.log(x)
  psi = lax.fori_loop(_c(n, 1), n, lambda i, psi: psi + one / i, psi)
  n1 = jnp.where(n == _c(n, 1), one + one, n)
  init = dict(
    x=x,
    z=-x,
    xk=zero,
    yk=one,
    pk=one - n,
    ans=jnp.where(n == _c(n, 1), zero, one / (one - n1)),
    t=jnp.inf,
  )

  def body(d):
    d["xk"] += one
    d["yk"] *= d["z"] / d["xk"]
    d["pk"] += one
    d["ans"] += jnp.where(d["pk"] != zero, d["yk"] / d["pk"], zero)
    d["t"] = jnp.where(d["ans"] != zero, abs(d["yk"] / d["ans"]), one)
    return d

  def cond(d):
    return (d["x"] > _c(d["x"], 0.0)) & (d["t"] > MACHEP)

  d = lax.while_loop(cond, body, init)
  t = n
  r = n - _c(n, 1)
  return d["z"]**r * psi / jnp.exp(gammaln(t)) - d["ans"]
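
# Usage sketch (illustrative): `_expn1` appears to be the power-series branch of the
# generalized exponential integral, used in the small-x regime (the continued-fraction
# branch `_expn2` above handles x > 1). Assuming the public wrapper is
# `jax.scipy.special.expn`, E_2(0.5) should be close to the reference value ~0.3266.
from jax.scipy.special import expn

print(expn(2, 0.5))   # expected: ~0.32664
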
def norm(x, ord=None, axis: Union[None, Tuple[int, ...], int] = None,
         keepdims=False):
  x = _promote_arg_dtypes(jnp.asarray(x))
  x_shape = jnp.shape(x)
  ndim = len(x_shape)

  if axis is None:
    # NumPy has an undocumented behavior that admits arbitrary rank inputs if
    # `ord` is None: https://github.com/numpy/numpy/issues/14215
    if ord is None:
      return jnp.sqrt(jnp.sum(jnp.real(x * jnp.conj(x)), keepdims=keepdims))
    axis = tuple(range(ndim))
  elif isinstance(axis, tuple):
    axis = tuple(canonicalize_axis(x, ndim) for x in axis)
  else:
    axis = (canonicalize_axis(axis, ndim),)

  num_axes = len(axis)
  if num_axes == 1:
    if ord is None or ord == 2:
      return jnp.sqrt(jnp.sum(jnp.real(x * jnp.conj(x)), axis=axis,
                              keepdims=keepdims))
    elif ord == jnp.inf:
      return jnp.amax(jnp.abs(x), axis=axis, keepdims=keepdims)
    elif ord == -jnp.inf:
      return jnp.amin(jnp.abs(x), axis=axis, keepdims=keepdims)
    elif ord == 0:
      return jnp.sum(x != 0, dtype=jnp.finfo(lax.dtype(x)).dtype,
                     axis=axis, keepdims=keepdims)
    elif ord == 1:
      # Numpy has a special case for ord == 1 as an optimization. We don't
      # really need the optimization (XLA could do it for us), but the Numpy
      # code has slightly different type promotion semantics, so we need a
      # special case too.
      return jnp.sum(jnp.abs(x), axis=axis, keepdims=keepdims)
    else:
      abs_x = jnp.abs(x)
      ord = lax._const(abs_x, ord)
      out = jnp.sum(abs_x ** ord, axis=axis, keepdims=keepdims)
      return jnp.power(out, 1. / ord)
  elif num_axes == 2:
    row_axis, col_axis = cast(Tuple[int, ...], axis)
    if ord is None or ord in ('f', 'fro'):
      return jnp.sqrt(jnp.sum(jnp.real(x * jnp.conj(x)), axis=axis,
                              keepdims=keepdims))
    elif ord == 1:
      if not keepdims and col_axis > row_axis:
        col_axis -= 1
      return jnp.amax(jnp.sum(jnp.abs(x), axis=row_axis, keepdims=keepdims),
                      axis=col_axis, keepdims=keepdims)
    elif ord == -1:
      if not keepdims and col_axis > row_axis:
        col_axis -= 1
      return jnp.amin(jnp.sum(jnp.abs(x), axis=row_axis, keepdims=keepdims),
                      axis=col_axis, keepdims=keepdims)
    elif ord == jnp.inf:
      if not keepdims and row_axis > col_axis:
        row_axis -= 1
      return jnp.amax(jnp.sum(jnp.abs(x), axis=col_axis, keepdims=keepdims),
                      axis=row_axis, keepdims=keepdims)
    elif ord == -jnp.inf:
      if not keepdims and row_axis > col_axis:
        row_axis -= 1
      return jnp.amin(jnp.sum(jnp.abs(x), axis=col_axis, keepdims=keepdims),
                      axis=row_axis, keepdims=keepdims)
    elif ord in ('nuc', 2, -2):
      x = jnp.moveaxis(x, axis, (-2, -1))
      if ord == 2:
        reducer = jnp.amax
      elif ord == -2:
        reducer = jnp.amin
      else:
        reducer = jnp.sum
      y = reducer(svd(x, compute_uv=False), axis=-1)
      if keepdims:
        result_shape = list(x_shape)
        result_shape[axis[0]] = 1
        result_shape[axis[1]] = 1
        y = jnp.reshape(y, result_shape)
      return y
    else:
      raise ValueError("Invalid order '{}' for matrix norm.".format(ord))
  else:
    raise ValueError(
      "Invalid axis values ({}) for jnp.linalg.norm.".format(axis))
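
# Usage sketch (illustrative): assuming this backs `jnp.linalg.norm`, vector and
# matrix `ord` values dispatch to the different branches above.
import jax.numpy as jnp

v = jnp.array([3.0, 4.0])
M = jnp.array([[1.0, 2.0],
               [3.0, 4.0]])
print(jnp.linalg.norm(v))               # vector 2-norm: 5.0
print(jnp.linalg.norm(v, ord=jnp.inf))  # max |v_i|: 4.0
print(jnp.linalg.norm(M, ord='fro'))    # Frobenius norm: ~5.477
print(jnp.linalg.norm(M, ord=2))        # largest singular value: ~5.465
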
def polyfit(x, y, deg, rcond=None, full=False, w=None, cov=False):
  _check_arraylike("polyfit", x, y)
  deg = core.concrete_or_error(int, deg, "deg must be int")
  order = deg + 1
  # check arguments
  if deg < 0:
    raise ValueError("expected deg >= 0")
  if x.ndim != 1:
    raise TypeError("expected 1D vector for x")
  if x.size == 0:
    raise TypeError("expected non-empty vector for x")
  if y.ndim < 1 or y.ndim > 2:
    raise TypeError("expected 1D or 2D array for y")
  if x.shape[0] != y.shape[0]:
    raise TypeError("expected x and y to have same length")

  # set rcond
  if rcond is None:
    rcond = len(x) * finfo(x.dtype).eps
  rcond = core.concrete_or_error(float, rcond, "rcond must be float")

  # set up least squares equation for powers of x
  lhs = vander(x, order)
  rhs = y

  # apply weighting
  if w is not None:
    _check_arraylike("polyfit", w)
    w, = _promote_dtypes_inexact(w)
    if w.ndim != 1:
      raise TypeError("expected a 1-d array for weights")
    if w.shape[0] != y.shape[0]:
      raise TypeError("expected w and y to have the same length")
    lhs *= w[:, np.newaxis]
    if rhs.ndim == 2:
      rhs *= w[:, np.newaxis]
    else:
      rhs *= w

  # scale lhs to improve condition number and solve
  scale = sqrt((lhs * lhs).sum(axis=0))
  lhs /= scale[np.newaxis, :]
  c, resids, rank, s = linalg.lstsq(lhs, rhs, rcond)
  c = (c.T / scale).T  # broadcast scale coefficients

  if full:
    return c, resids, rank, s, rcond
  elif cov:
    Vbase = linalg.inv(dot(lhs.T, lhs))
    Vbase /= outer(scale, scale)
    if cov == "unscaled":
      fac = 1
    else:
      if len(x) <= order:
        raise ValueError("the number of data points must exceed order "
                         "to scale the covariance matrix")
      fac = resids / (len(x) - order)
      fac = fac[0]  # convert shape (1,) array to a scalar
    if y.ndim == 1:
      return c, Vbase * fac
    else:
      return c, Vbase[:, :, np.newaxis] * fac
  else:
    return c
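
# Usage sketch (illustrative): assuming this backs `jnp.polyfit`, fitting a
# quadratic to noiseless samples of y = x**2 - 2x + 3 recovers the coefficients.
import jax.numpy as jnp

x = jnp.linspace(-2.0, 2.0, 9)
y = x**2 - 2.0 * x + 3.0
c = jnp.polyfit(x, y, 2)
print(c)   # expected: approximately [1., -2., 3.]
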
def _projector_subspace(P, H, n, rank, maxiter=2):
  """ Decomposes the `n x n` rank `rank` Hermitian projector `P` into an
  `n x rank` isometry `V_minus` such that `P = V_minus @ V_minus.conj().T`
  and an `n x (n - rank)` isometry `V_plus` such that
  `I - P = V_plus @ V_plus.conj().T`.

  The subspaces are computed using the naive QR eigendecomposition algorithm,
  which converges very quickly due to the sharp separation between the
  relevant eigenvalues of the projector.

  Args:
    P: A rank-`rank` Hermitian projector into the space of `H`'s first `rank`
      eigenpairs. `P` is padded to NxN.
    H: The aforementioned Hermitian matrix, which is used to track convergence.
    n: the true (dynamic) shape of `P`.
    rank: Rank of `P`.
    maxiter: Maximum number of iterations.
  Returns:
    V_minus, V_plus: Isometries into the eigenspaces described in the docstring.
  """
  # Choose an initial guess: the `rank` largest-norm columns of P.
  N, _ = P.shape
  column_norms = jnp_linalg.norm(P, axis=1)
  # `jnp.argsort` ensures NaNs sort last, so set masked-out column norms to NaN.
  column_norms = _mask(column_norms, (n,), jnp.nan)
  sort_idxs = jnp.argsort(column_norms)
  X = P[:, sort_idxs]
  # X = X[:, :rank]
  X = _mask(X, (n, rank))

  H_norm = jnp_linalg.norm(H)
  thresh = 10 * jnp.finfo(X.dtype).eps * H_norm

  # First iteration skips the matmul.
  def body_f_after_matmul(X):
    Q, _ = jnp_linalg.qr(X, mode="complete")
    # V1 = Q[:, :rank]
    # V2 = Q[:, rank:]
    V1 = _mask(Q, (n, rank))
    V2 = _slice(Q, (0, rank), (n, n - rank), (N, N))

    # TODO: might be able to get away with lower precision here
    error_matrix = jnp.dot(V2.conj().T, H)
    error_matrix = jnp.dot(error_matrix, V1)
    error = jnp_linalg.norm(error_matrix) / H_norm
    return V1, V2, error

  def cond_f(args):
    _, _, j, error = args
    still_counting = j < maxiter
    unconverged = error > thresh
    return jnp.logical_and(still_counting, unconverged)[0]

  def body_f(args):
    V1, _, j, _ = args
    X = jnp.dot(P, V1)
    V1, V2, error = body_f_after_matmul(X)
    return V1, V2, j + 1, error

  V1, V2, error = body_f_after_matmul(X)
  one = jnp.ones(1, dtype=jnp.int32)
  V1, V2, _, error = lax.while_loop(cond_f, body_f, (V1, V2, one, error))
  return V1, V2
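
# Illustrative sketch (does not call the private helper above, which relies on the
# internal `_mask`/`_slice` padding utilities): the underlying idea is that a
# complete QR decomposition of the `rank` largest-norm columns of a Hermitian
# projector P yields an orthonormal basis V1 of range(P) and a basis V2 of its
# orthogonal complement, i.e. P = V1 @ V1^H and I - P = V2 @ V2^H.
import jax.numpy as jnp

rank = 2
u1 = jnp.array([1.0, 1.0, 0.0]) / jnp.sqrt(2.0)
u2 = jnp.array([0.0, 0.0, 1.0])
P = jnp.outer(u1, u1) + jnp.outer(u2, u2)            # rank-2 projector in 3 dimensions

# Initial guess: the `rank` largest-norm columns of P, as in the helper above.
idx = jnp.argsort(jnp.linalg.norm(P, axis=0))[::-1][:rank]
Q, _ = jnp.linalg.qr(P[:, idx], mode="complete")
V1, V2 = Q[:, :rank], Q[:, rank:]

print(jnp.allclose(V1 @ V1.T, P, atol=1e-5))               # expected: True
print(jnp.allclose(V2 @ V2.T, jnp.eye(3) - P, atol=1e-5))  # expected: True
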