def body(k, state):
  """Runs step ``k`` of an unblocked LU factorization with partial pivoting.

  ``m`` and ``n`` come from the enclosing scope. ``jnp.arange(m)`` is masked
  against ``a[:, k]`` and ``jnp.arange(n)`` against the columns of
  ``jnp.outer(a[:, k], a[k, :])``, so they must equal the row and column
  counts of ``a``.

  Args:
    k: index of the current column (and of the row that receives the pivot).
    state: tuple ``(pivot, perm, a)``.
      - ``pivot[k]`` is set to the row chosen as pivot at this step.
      - ``perm`` has rows ``k`` and ``i`` swapped in the same way as ``a``.
      - ``a`` is the matrix being factored in place.

  Returns:
    The updated ``(pivot, perm, a)`` tuple.
  """
  pivot, perm, a = state
  m_idx = jnp.arange(m)
  n_idx = jnp.arange(n)

  # For complex inputs the pivot magnitude is |Re| + |Im| rather than the
  # modulus (presumably mirroring LAPACK's i?amax convention -- confirm).
  if jnp.issubdtype(a.dtype, jnp.complexfloating):
    t = a[:, k]
    magnitude = jnp.abs(jnp.real(t)) + jnp.abs(jnp.imag(t))
  else:
    magnitude = jnp.abs(a[:, k])
  # Only rows k.. are pivot candidates; earlier rows are masked with -inf.
  i = jnp.argmax(jnp.where(m_idx >= k, magnitude, -jnp.inf))
  pivot = pivot.at[k].set(i)
  # Swap rows k and i of `a` and the matching entries of `perm`.
  a = a.at[[k, i], ].set(a[[i, k], ])
  perm = perm.at[[i, k], ].set(perm[[k, i], ])

  # a[k+1:, k] /= a[k, k], adapted for loop-invariant shapes
  x = a[k, k]
  a = a.at[:, k].set(jnp.where(m_idx > k, a[:, k] / x, a[:, k]))

  # a[k+1:, k+1:] -= jnp.outer(a[k+1:, k], a[k, k+1:])
  a = a - jnp.where(
      (m_idx[:, None] > k) & (n_idx > k),
      jnp.outer(a[:, k], a[k, :]),
      jnp.array(0, dtype=a.dtype))
  return pivot, perm, a
def _rsf2scf_iter(i, TZ):
  """Processes subdiagonal position ``m = N - i`` of ``T``.

  If ``|T[m, m - 1]|`` is not negligible compared with ``eps`` times the sum
  of the neighbouring diagonal magnitudes, ``_update_T_Z(m, T, Z)`` is
  applied. In either case ``T[m, m - 1]`` is then set to exactly zero.

  ``N``, ``eps`` and ``_update_T_Z`` come from the enclosing scope. The name
  suggests one step of a real-to-complex Schur form conversion (cf.
  ``scipy.linalg.rsf2csf``) -- NOTE(review): confirm against the caller.

  Args:
    i: loop counter; the row processed is ``N - i``.
    TZ: tuple ``(T, Z)`` of arrays carried through the loop.

  Returns:
    The updated ``(T, Z)`` tuple.
  """
  T, Z = TZ
  row = N - i
  subdiag = jnp.abs(T[row, row - 1])
  scale = jnp.abs(T[row - 1, row - 1]) + jnp.abs(T[row, row])

  def _unchanged(_, T, Z):
    return T, Z

  T, Z = lax.cond(subdiag > eps * scale, _update_T_Z, _unchanged, row, T, Z)
  return T.at[row, row - 1].set(0.0), Z
def _algorithm_11_1_1(F, T): # Algorithm 11.1.1 from Golub and Van Loan "Matrix Computations" N = T.shape[0] minden = jnp.abs(T[0, 0]) def _outer_loop(p, F_minden): _, F, minden = lax.fori_loop(1, N-p+1, _inner_loop, (p, *F_minden)) return F, minden def _inner_loop(i, p_F_minden): p, F, minden = p_F_minden j = i+p s = T[i-1, j-1] * (F[j-1, j-1] - F[i-1, i-1]) T_row, T_col = T[i-1], T[:, j-1] F_row, F_col = F[i-1], F[:, j-1] ind = (jnp.arange(N) >= i) & (jnp.arange(N) < j-1) val = (jnp.where(ind, T_row, 0) @ jnp.where(ind, F_col, 0) - jnp.where(ind, F_row, 0) @ jnp.where(ind, T_col, 0)) s = s + val den = T[j-1, j-1] - T[i-1, i-1] s = jnp.where(den != 0, s / den, s) F = F.at[i-1, j-1].set(s) minden = jnp.minimum(minden, jnp.abs(den)) return p, F, minden return lax.fori_loop(1, N, _outer_loop, (F, minden))
def _slogdet_lu(a): dtype = lax.dtype(a) lu, pivot, _ = lax_linalg.lu(a) diag = jnp.diagonal(lu, axis1=-2, axis2=-1) is_zero = jnp.any(diag == jnp.array(0, dtype=dtype), axis=-1) iota = lax.expand_dims(jnp.arange(a.shape[-1]), range(pivot.ndim - 1)) parity = jnp.count_nonzero(pivot != iota, axis=-1) if jnp.iscomplexobj(a): sign = jnp.prod(diag / jnp.abs(diag), axis=-1) else: sign = jnp.array(1, dtype=dtype) parity = parity + jnp.count_nonzero(diag < 0, axis=-1) sign = jnp.where(is_zero, jnp.array(0, dtype=dtype), sign * jnp.array(-2 * (parity % 2) + 1, dtype=dtype)) logdet = jnp.where(is_zero, jnp.array(-jnp.inf, dtype=dtype), jnp.sum(jnp.log(jnp.abs(diag)), axis=-1)) return sign, jnp.real(logdet)
def slogdet(a):
  """Computes the sign and natural log of the absolute value of ``det(a)``.

  Args:
    a: array-like of shape ``[..., n, n]``.

  Returns:
    ``(sign, logabsdet)`` computed from an LU factorization. If a pivot of
    ``U`` is exactly zero, ``sign`` is 0 and ``logabsdet`` is ``-inf``.

  Raises:
    ValueError: if ``a`` is not at least 2-D with square trailing dimensions.
  """
  a = _promote_arg_dtypes(jnp.asarray(a))
  a_shape = jnp.shape(a)
  if len(a_shape) < 2 or a_shape[-1] != a_shape[-2]:
    msg = "Argument to slogdet() must have shape [..., n, n], got {}"
    raise ValueError(msg.format(a_shape))
  # Delegate to the shared LU implementation rather than duplicating it here,
  # so the two code paths cannot drift apart.
  return _slogdet_lu(a)
def _inner_loop(i, p_F_minden):
  """Computes one entry of the Parlett recurrence (Golub & Van Loan 11.1.1).

  With 1-based textbook indices ``(i, j = i + p)``, this sets
  ``F[i-1, j-1]`` (0-based) and folds ``|T[j-1, j-1] - T[i-1, i-1]|`` into the
  running minimum. ``T`` and ``N`` are taken from the enclosing scope.

  Args:
    i: 1-based row index of the entry being filled.
    p_F_minden: tuple ``(p, F, minden)``: superdiagonal offset, the array
      being filled, and the running minimum of diagonal differences.

  Returns:
    ``(p, F, minden)`` with the entry written and ``minden`` updated. When
    the diagonal difference is zero, the entry is stored undivided.
  """
  p, F, minden = p_F_minden
  row, col = i - 1, i + p - 1
  k = jnp.arange(N)
  # 0-based positions i .. j - 2, i.e. 1-based k strictly between i and j.
  mask = (k >= i) & (k < col)
  s = T[row, col] * (F[col, col] - F[row, row])
  lhs = jnp.where(mask, T[row], 0) @ jnp.where(mask, F[:, col], 0)
  rhs = jnp.where(mask, F[row], 0) @ jnp.where(mask, T[:, col], 0)
  s = s + (lhs - rhs)
  den = T[col, col] - T[row, row]
  F = F.at[row, col].set(jnp.where(den != 0, s / den, s))
  return p, F, jnp.minimum(minden, jnp.abs(den))
def _slogdet_qr(a): # Implementation of slogdet using QR decomposition. One reason we might prefer # QR decomposition is that it is more amenable to a fast batched # implementation on TPU because of the lack of row pivoting. if jnp.issubdtype(lax.dtype(a), jnp.complexfloating): raise NotImplementedError("slogdet method='qr' not implemented for complex " "inputs") n = a.shape[-1] a, taus = lax_linalg.geqrf(a) # The determinant of a triangular matrix is the product of its diagonal # elements. We are working in log space, so we compute the magnitude as the # the trace of the log-absolute values, and we compute the sign separately. log_abs_det = jnp.trace(jnp.log(jnp.abs(a)), axis1=-2, axis2=-1) sign_diag = jnp.prod(jnp.sign(jnp.diagonal(a, axis1=-2, axis2=-1)), axis=-1) # The determinant of a Householder reflector is -1. So whenever we actually # made a reflection (tau != 0), multiply the result by -1. sign_taus = jnp.prod(jnp.where(taus[..., :(n-1)] != 0, -1, 1), axis=-1).astype(sign_diag.dtype) return sign_diag * sign_taus, log_abs_det
def norm(x, ord=None, axis: Union[None, Tuple[int, ...], int] = None,
         keepdims=False):
  """Computes a vector or matrix norm, following ``numpy.linalg.norm``.

  Args:
    x: input array.
    ord: order of the norm.
      - Vector norms (one reduction axis): ``None``/``2``, ``inf``, ``-inf``,
        ``0`` (count of non-zeros), ``1``, or any other number ``p``, giving
        ``sum(|x|**p)**(1/p)``.
      - Matrix norms (two reduction axes): ``None``/``'f'``/``'fro'``, ``1``,
        ``-1``, ``inf``, ``-inf``, ``2``, ``-2`` or ``'nuc'``.
    axis: ``None``, an int, or a tuple of ints selecting the reduction axes.
      With ``axis=None`` and ``ord=None``, the 2-norm of the flattened input
      is returned for any rank. With ``axis=None`` and another ``ord``, all
      axes are reduced.
    keepdims: if true, reduced axes are kept with size 1.

  Returns:
    The norm, with the reduced axes removed (or kept if ``keepdims``).

  Raises:
    ValueError: for an invalid ``ord``, duplicate matrix axes, or a number of
      reduction axes other than 1 or 2.
  """
  x = _promote_arg_dtypes(jnp.asarray(x))
  x_shape = jnp.shape(x)
  ndim = len(x_shape)

  if axis is None:
    # NumPy has an undocumented behavior that admits arbitrary rank inputs if
    # `ord` is None: https://github.com/numpy/numpy/issues/14215
    if ord is None:
      return jnp.sqrt(
          jnp.sum(jnp.real(x * jnp.conj(x)), keepdims=keepdims))
    axis = tuple(range(ndim))
  elif isinstance(axis, tuple):
    axis = tuple(canonicalize_axis(ax, ndim) for ax in axis)
  else:
    axis = (canonicalize_axis(axis, ndim),)

  num_axes = len(axis)
  if num_axes == 1:
    if ord is None or ord == 2:
      return jnp.sqrt(
          jnp.sum(jnp.real(x * jnp.conj(x)), axis=axis, keepdims=keepdims))
    elif ord == jnp.inf:
      return jnp.amax(jnp.abs(x), axis=axis, keepdims=keepdims)
    elif ord == -jnp.inf:
      return jnp.amin(jnp.abs(x), axis=axis, keepdims=keepdims)
    elif ord == 0:
      return jnp.sum(x != 0, dtype=jnp.finfo(lax.dtype(x)).dtype,
                     axis=axis, keepdims=keepdims)
    elif ord == 1:
      # Numpy has a special case for ord == 1 as an optimization. We don't
      # really need the optimization (XLA could do it for us), but the Numpy
      # code has slightly different type promotion semantics, so we need a
      # special case too.
      return jnp.sum(jnp.abs(x), axis=axis, keepdims=keepdims)
    elif isinstance(ord, str):
      # Matrix-only orders such as 'fro' or 'nuc' are meaningless for vectors;
      # report that clearly (as NumPy does) instead of failing inside the
      # numeric conversion below.
      raise ValueError(f"Invalid norm order '{ord}' for vectors")
    else:
      abs_x = jnp.abs(x)
      ord = lax._const(abs_x, ord)
      out = jnp.sum(abs_x**ord, axis=axis, keepdims=keepdims)
      return jnp.power(out, 1. / ord)

  elif num_axes == 2:
    row_axis, col_axis = cast(Tuple[int, ...], axis)
    if row_axis == col_axis:
      # Without this check, e.g. ord=1 would silently reduce the same axis
      # twice and return a meaningless value.
      raise ValueError('Duplicate axes given.')
    if ord is None or ord in ('f', 'fro'):
      return jnp.sqrt(
          jnp.sum(jnp.real(x * jnp.conj(x)), axis=axis, keepdims=keepdims))
    elif ord == 1:
      # Max absolute column sum. Removing row_axis shifts a later col_axis.
      if not keepdims and col_axis > row_axis:
        col_axis -= 1
      return jnp.amax(jnp.sum(jnp.abs(x), axis=row_axis, keepdims=keepdims),
                      axis=col_axis, keepdims=keepdims)
    elif ord == -1:
      if not keepdims and col_axis > row_axis:
        col_axis -= 1
      return jnp.amin(jnp.sum(jnp.abs(x), axis=row_axis, keepdims=keepdims),
                      axis=col_axis, keepdims=keepdims)
    elif ord == jnp.inf:
      # Max absolute row sum. Removing col_axis shifts a later row_axis.
      if not keepdims and row_axis > col_axis:
        row_axis -= 1
      return jnp.amax(jnp.sum(jnp.abs(x), axis=col_axis, keepdims=keepdims),
                      axis=row_axis, keepdims=keepdims)
    elif ord == -jnp.inf:
      if not keepdims and row_axis > col_axis:
        row_axis -= 1
      return jnp.amin(jnp.sum(jnp.abs(x), axis=col_axis, keepdims=keepdims),
                      axis=row_axis, keepdims=keepdims)
    elif ord in ('nuc', 2, -2):
      # Singular-value based norms: move the matrix axes last, then reduce the
      # singular values with max (2), min (-2) or sum ('nuc').
      x = jnp.moveaxis(x, axis, (-2, -1))
      if ord == 2:
        reducer = jnp.amax
      elif ord == -2:
        reducer = jnp.amin
      else:
        reducer = jnp.sum
      y = reducer(svd(x, compute_uv=False), axis=-1)
      if keepdims:
        result_shape = list(x_shape)
        result_shape[axis[0]] = 1
        result_shape[axis[1]] = 1
        y = jnp.reshape(y, result_shape)
      return y
    else:
      raise ValueError("Invalid order '{}' for matrix norm.".format(ord))
  else:
    raise ValueError(
        "Invalid axis values ({}) for jnp.linalg.norm.".format(axis))
def eigh_tridiagonal(d, e, *, eigvals_only=False, select='a',
                     select_range=None, tol=None):
  """Computes eigenvalues of a Hermitian tridiagonal matrix by bisection.

  The matrix has diagonal ``d`` and off-diagonal ``e``. All requested
  eigenvalues are located simultaneously by parallel bisection driven by
  Sturm-sequence counts. The search starts from a slightly widened Gershgorin
  interval.

  Args:
    d: diagonal entries, shape ``(n,)``.
    e: off-diagonal entries, shape ``(n - 1,)``, with the same dtype as ``d``
      (float32, float64, complex64 or complex128). For complex inputs only
      ``real(d)`` and ``|e|**2`` are used.
    eigvals_only: must be true; eigenvector computation is not implemented.
    select: ``'a'`` for all eigenvalues, or ``'i'`` for those whose 0-based
      ascending indices lie in the inclusive range ``select_range``. ``'v'``
      (value range) is not implemented.
    select_range: ``(lo, hi)`` index pair used when ``select == 'i'``; must
      satisfy ``0 <= lo <= hi <= n - 1``.
    tol: optional absolute bisection tolerance. It is never allowed to be
      smaller than ``eps`` times the estimated 2-norm of the matrix.

  Returns:
    The selected eigenvalues in ascending order, as a real array. If
    ``n <= 1``, the real part of ``d`` is returned without applying
    ``select``.

  Raises:
    NotImplementedError: if ``eigvals_only`` is false or ``select == 'v'``.
    TypeError: for mismatched or unsupported dtypes.
    ValueError: for an invalid ``select`` or ``select_range``.
  """
  if not eigvals_only:
    raise NotImplementedError(
        "Calculation of eigenvectors is not implemented")

  def _sturm(alpha, beta_sq, pivmin, alpha0_perturbation, x):
    """Implements the Sturm sequence recurrence.

    Returns, elementwise over ``x``, the number of non-positive pivots of
    ``T - x*I``. This is the count of eigenvalues of ``T`` below ``x``.
    """
    n = alpha.shape[0]
    zeros = jnp.zeros(x.shape, dtype=jnp.int32)
    ones = jnp.ones(x.shape, dtype=jnp.int32)

    # The first step in the Sturm sequence recurrence
    # requires special care if x is equal to alpha[0].
    def sturm_step0():
      q = alpha[0] - x
      count = jnp.where(q < 0, ones, zeros)
      q = jnp.where(alpha[0] == x, alpha0_perturbation, q)
      return q, count

    # Subsequent steps all take this form:
    def sturm_step(i, q, count):
      q = alpha[i] - beta_sq[i - 1] / q - x
      count = jnp.where(q <= pivmin, count + 1, count)
      q = jnp.where(q <= pivmin, jnp.minimum(q, -pivmin), q)
      return q, count

    def run_steps(start, q, count, num_steps):
      # Runs `num_steps` consecutive recurrence steps, unrolled at trace time.
      for j in range(num_steps):
        q, count = sturm_step(start + j, q, count)
      return start + num_steps, q, count

    # The first step initializes q and count.
    q, count = sturm_step0()

    # Peel off ((n-1) % blocksize) steps from the main loop, so we can run
    # the bulk of the iterations unrolled by a factor of blocksize.
    blocksize = 16
    i, q, count = run_steps(1, q, count, (n - 1) % blocksize)

    # Run the remaining steps of the Sturm sequence using a partially
    # unrolled while loop.
    def cond(iqc):
      i, _, _ = iqc
      return jnp.less(i, n)

    def sturm_block(iqc):
      i, q, count = iqc
      return run_steps(i, q, count, blocksize)

    _, _, count = lax.while_loop(cond, sturm_block, (i, q, count))
    return count

  alpha = jnp.asarray(d)
  beta = jnp.asarray(e)
  supported_dtypes = (jnp.float32, jnp.float64, jnp.complex64, jnp.complex128)
  if alpha.dtype != beta.dtype:
    raise TypeError(
        "diagonal and off-diagonal values must have same dtype, "
        f"got {alpha.dtype} and {beta.dtype}")
  if alpha.dtype not in supported_dtypes or beta.dtype not in supported_dtypes:
    raise TypeError(
        "Only float32 and float64 inputs are supported as inputs "
        "to jax.scipy.linalg.eigh_tridiagonal, got "
        f"{alpha.dtype} and {beta.dtype}")

  n = alpha.shape[0]
  if n <= 1:
    return jnp.real(alpha)

  if jnp.issubdtype(alpha.dtype, jnp.complexfloating):
    alpha = jnp.real(alpha)
    beta_sq = jnp.real(beta * jnp.conj(beta))
    beta_abs = jnp.sqrt(beta_sq)
  else:
    beta_abs = jnp.abs(beta)
    beta_sq = jnp.square(beta)

  # Estimate the largest and smallest eigenvalues of T using the Gershgorin
  # circle theorem.
  off_diag_abs_row_sum = jnp.concatenate(
      [beta_abs[:1], beta_abs[:-1] + beta_abs[1:], beta_abs[-1:]], axis=0)
  lambda_est_max = jnp.amax(alpha + off_diag_abs_row_sum)
  lambda_est_min = jnp.amin(alpha - off_diag_abs_row_sum)
  # Upper bound on 2-norm of T.
  t_norm = jnp.maximum(jnp.abs(lambda_est_min), jnp.abs(lambda_est_max))

  # Compute the smallest allowed pivot in the Sturm sequence to avoid
  # overflow.
  finfo = np.finfo(alpha.dtype)
  one = np.ones([], dtype=alpha.dtype)
  safemin = np.maximum(one / finfo.max, (one + finfo.eps) * finfo.tiny)
  pivmin = safemin * jnp.maximum(1, jnp.amax(beta_sq))
  alpha0_perturbation = jnp.square(finfo.eps * beta_abs[0])
  abs_tol = finfo.eps * t_norm
  if tol is not None:
    abs_tol = jnp.maximum(tol, abs_tol)
  # In the worst case, when the absolute tolerance is eps*lambda_est_max and
  # lambda_est_max = -lambda_est_min, we have to take as many bisection steps
  # as there are bits in the mantissa plus 1.
  max_it = finfo.nmant + 1

  # Determine the indices of the desired eigenvalues, based on select and
  # select_range.
  if select == 'a':
    target_counts = jnp.arange(n, dtype=jnp.int32)
  elif select == 'i':
    if select_range[0] > select_range[1]:
      raise ValueError('Got empty index range in select_range.')
    # Indices outside [0, n-1] have no eigenvalue; bisection would silently
    # converge to an endpoint of the search interval instead.
    if select_range[0] < 0 or select_range[1] >= n:
      raise ValueError(
          f"select_range {tuple(select_range)} is out of bounds for a "
          f"matrix of size {n}.")
    target_counts = jnp.arange(select_range[0], select_range[1] + 1,
                               dtype=jnp.int32)
  elif select == 'v':
    # TODO(phawkins): requires dynamic shape support.
    raise NotImplementedError("eigh_tridiagonal(..., select='v') is not "
                              "implemented")
  else:
    raise ValueError("'select must have a value in {'a', 'i', 'v'}.")

  # Run binary search for all desired eigenvalues in parallel, starting from
  # an interval slightly wider than the estimated
  # [lambda_est_min, lambda_est_max].
  fudge = 2.1  # Widen the starting Gershgorin interval a bit.
  norm_slack = jnp.array(n, alpha.dtype) * fudge * finfo.eps * t_norm
  lower = lambda_est_min - norm_slack - 2 * fudge * pivmin
  upper = lambda_est_max + norm_slack + fudge * pivmin

  # Pre-broadcast the scalars used in the Sturm sequence for improved
  # performance.
  target_shape = jnp.shape(target_counts)
  lower = jnp.broadcast_to(lower, shape=target_shape)
  upper = jnp.broadcast_to(upper, shape=target_shape)
  mid = 0.5 * (upper + lower)
  pivmin = jnp.broadcast_to(pivmin, target_shape)
  alpha0_perturbation = jnp.broadcast_to(alpha0_perturbation, target_shape)

  # Start parallel binary searches. For target index k, count(mid) <= k means
  # the k-th smallest eigenvalue is >= mid, so the lower bound moves up.
  def cond(args):
    i, lower, _, upper = args
    return jnp.logical_and(
        jnp.less(i, max_it),
        jnp.less(abs_tol, jnp.amax(upper - lower)))

  def body(args):
    i, lower, mid, upper = args
    counts = _sturm(alpha, beta_sq, pivmin, alpha0_perturbation, mid)
    lower = jnp.where(counts <= target_counts, mid, lower)
    upper = jnp.where(counts > target_counts, mid, upper)
    mid = 0.5 * (lower + upper)
    return i + 1, lower, mid, upper

  _, _, mid, _ = lax.while_loop(cond, body, (0, lower, mid, upper))
  return mid
def _mirror_index_fixer(index, size): s = size - 1 # Half-wavelength of triangular wave # Scaled, integer-valued version of the triangular wave |x - round(x)| return jnp.abs((index + s) % (2 * s) - s)