def convolve(in1, in2, mode='full', method='auto', precision=None):
    """Convolve ``in1`` with ``in2`` by delegating to ``_convolve_nd``.

    ``method`` is accepted for signature compatibility only; any value other
    than ``'auto'`` triggers a warning and is otherwise ignored.

    Raises:
      NotImplementedError: if either input has a complex dtype.
    """
    if method != 'auto':
        warnings.warn("convolve() ignores method argument")
    # any() short-circuits exactly like the chained `or` would.
    if any(jnp.issubdtype(arr.dtype, jnp.complexfloating) for arr in (in1, in2)):
        raise NotImplementedError("convolve() does not support complex inputs")
    return _convolve_nd(in1, in2, mode, precision=precision)
def correlate(in1, in2, mode='full', method='auto', precision=None):
    """Cross-correlate two 1-D real arrays.

    Implemented as a convolution of ``in1`` with ``in2`` reversed.
    ``method`` is accepted for signature compatibility only; any value other
    than ``'auto'`` triggers a warning and is otherwise ignored.

    Raises:
      NotImplementedError: if either input has a complex dtype.
      ValueError: if either input is not 1-dimensional.
    """
    if method != 'auto':
        warnings.warn("correlate() ignores method argument")
    if jnp.issubdtype(in1.dtype, jnp.complexfloating) or jnp.issubdtype(
            in2.dtype, jnp.complexfloating):
        raise NotImplementedError(
            "correlate() does not support complex inputs")
    if jnp.ndim(in1) != 1 or jnp.ndim(in2) != 1:
        # The previous message was a plain string with a literal "{ndim}"
        # placeholder that was never formatted; state the real requirement.
        raise ValueError("correlate() only supports 1-dimensional inputs.")
    # Correlation equals convolution with the second operand reversed.
    return _convolve_nd(in1, in2[::-1], mode, precision=precision)
def body(k, state):
    """Perform pivot step ``k`` of a partially pivoted LU factorization.

    NOTE(review): ``m`` and ``n`` are free variables taken from an enclosing
    scope that is not visible here; presumably they are the row and column
    counts of ``a``. Confirm this against the caller.

    Args:
      k: index of the current pivot column.
      state: tuple ``(pivot, perm, a)`` carried between loop iterations.

    Returns:
      The updated ``(pivot, perm, a)`` tuple.
    """
    pivot, perm, a = state
    # Index vectors used to build masks instead of dynamic slices, so every
    # array keeps the same shape on each iteration.
    m_idx = jnp.arange(m)
    n_idx = jnp.arange(n)

    # Pivot score for column k: |Re| + |Im| for complex dtypes (presumably to
    # avoid the square root of a true complex modulus), |x| otherwise.
    if jnp.issubdtype(a.dtype, jnp.complexfloating):
        t = a[:, k]
        magnitude = jnp.abs(jnp.real(t)) + jnp.abs(jnp.imag(t))
    else:
        magnitude = jnp.abs(a[:, k])
    # Pick the largest-scoring row at or below the diagonal; rows above k are
    # masked out with -inf.
    i = jnp.argmax(jnp.where(m_idx >= k, magnitude, -jnp.inf))
    pivot = ops.index_update(pivot, ops.index[k], i)
    # Swap rows k and i of a, and the matching entries of perm.
    a = ops.index_update(a, ops.index[[k, i], ], a[[i, k], ])
    perm = ops.index_update(perm, ops.index[[i, k], ], perm[[k, i], ])

    # a[k+1:, k] /= a[k, k], adapted for loop-invariant shapes
    x = a[k, k]
    a = ops.index_update(a, ops.index[:, k],
                         jnp.where(m_idx > k, a[:, k] / x, a[:, k]))

    # a[k+1:, k+1:] -= jnp.outer(a[k+1:, k], a[k, k+1:])
    a = a - jnp.where(
        (m_idx[:, None] > k) & (n_idx > k),
        jnp.outer(a[:, k], a[k, :]), jnp.array(0, dtype=a.dtype))
    return pivot, perm, a
def _get_identity(op, dtype):
    """Return the identity element of scatter reduction ``op`` in ``dtype``.

    Additive and multiplicative scatters use 0 and 1. Min/max scatters use the
    extreme representable integer for integer dtypes and +/-infinity
    otherwise.

    Raises:
      ValueError: if ``op`` is not one of the supported scatter primitives.
    """
    if op is lax.scatter_add:
        return 0
    if op is lax.scatter_mul:
        return 1
    if op is lax.scatter_min or op is lax.scatter_max:
        integral = jnp.issubdtype(dtype, jnp.integer)
        if op is lax.scatter_min:
            return jnp.iinfo(dtype).max if integral else float('inf')
        return jnp.iinfo(dtype).min if integral else -float('inf')
    raise ValueError(f"Unrecognized op: {op}")
def convolve2d(in1, in2, mode='full', boundary='fill', fillvalue=0,
               precision=None):
    """Convolve two real 2-D arrays by delegating to ``_convolve_nd``.

    Only zero-fill boundary handling is supported.

    Raises:
      NotImplementedError: for a non-default ``boundary``/``fillvalue`` or for
        complex inputs.
      ValueError: if either input is not 2-dimensional.
    """
    default_padding = boundary == 'fill' and fillvalue == 0
    if not default_padding:
        raise NotImplementedError(
            "convolve2d() only supports boundary='fill', fillvalue=0")
    if any(jnp.issubdtype(arr.dtype, jnp.complexfloating) for arr in (in1, in2)):
        raise NotImplementedError(
            "convolve2d() does not support complex inputs")
    both_2d = jnp.ndim(in1) == 2 and jnp.ndim(in2) == 2
    if not both_2d:
        raise ValueError("convolve2d() only supports 2-dimensional inputs.")
    return _convolve_nd(in1, in2, mode, precision=precision)
def _nan_like(c, operand):
    """Build an XLA op that broadcasts NaN to the shape/dtype of ``operand``.

    Complex element types get NaN in both the real and imaginary parts.
    """
    operand_shape = c.get_shape(operand)
    elem_dtype = operand_shape.element_type()
    if jnp.issubdtype(elem_dtype, np.complexfloating):
        fill = np.nan * (1. + 1j)
    else:
        fill = np.nan
    nan_scalar = xb.constant(c, np.array(fill, dtype=elem_dtype))
    return xops.Broadcast(nan_scalar, operand_shape.dimensions())
def triangular_solve(a, b, left_side: bool = False, lower: bool = False,
                     transpose_a: bool = False, conjugate_a: bool = False,
                     unit_diagonal: bool = False):
    r"""Solve a (batched) triangular linear system.

    Solves :math:`\mathit{op}(A) \cdot X = B` when ``left_side`` is ``True``,
    or :math:`X \cdot \mathit{op}(A) = B` when ``left_side`` is ``False``.
    ``A`` must be a square lower or upper triangular matrix, and
    :math:`\mathit{op}(A)` optionally transposes :math:`A` (``transpose_a``)
    and/or takes its complex conjugate (``conjugate_a``).

    Args:
      a: A batch of matrices with shape ``[..., m, m]``.
      b: A batch of matrices with shape ``[..., m, n]`` if ``left_side`` is
        ``True`` or ``[..., n, m]`` otherwise. If ``b`` has exactly one
        dimension fewer than ``a``, it is treated as a batch of vectors of
        shape ``[..., m]`` and the result has that vector shape too.
      left_side: selects which of the two equations above is solved.
      lower: if ``True`` the lower triangle of ``a`` is used, otherwise the
        upper triangle; the other triangle is ignored.
      transpose_a: if ``True``, ``a`` is transposed.
      conjugate_a: if ``True``, the complex conjugate of ``a`` is used. It has
        no effect when ``a`` is real.
      unit_diagonal: if ``True``, the diagonal of ``a`` is assumed to be all
        ones and is not accessed.

    Returns:
      A batch of matrices with the same shape and dtype as ``b``.
    """
    # Conjugating a real matrix is a no-op, so only request it for complex a.
    if conjugate_a:
        conjugate_a = jnp.issubdtype(lax.dtype(a), jnp.complexfloating)
    is_vector = jnp.ndim(b) == jnp.ndim(a) - 1
    if is_vector:
        # Lift the vector to a one-column (left) or one-row (right) matrix.
        b = jnp.expand_dims(b, -1 if left_side else -2)
    solution = triangular_solve_p.bind(
        a, b, left_side=left_side, lower=lower, transpose_a=transpose_a,
        conjugate_a=conjugate_a, unit_diagonal=unit_diagonal)
    if not is_vector:
        return solution
    return solution[..., 0] if left_side else solution[..., 0, :]
def _promote_arg_dtypes(*args):
    """Promote ``args`` to a common inexact (floating or complex) dtype.

    Returns the single converted array when called with one argument, and a
    list of converted arrays otherwise.
    """
    common, weak = dtypes._lattice_result_type(*args)
    if not jnp.issubdtype(common, jnp.inexact):
        # Non-inexact inputs (e.g. integers) fall back to the default float
        # type, and the result is strongly typed.
        common, weak = jnp.float_, False
    common = dtypes.canonicalize_dtype(common)
    promoted = [lax._convert_element_type(x, common, weak) for x in args]
    return promoted[0] if len(promoted) == 1 else promoted
def _slogdet_jvp(primals, tangents):
    """JVP rule for ``slogdet``.

    Args:
      primals: one-element tuple ``(x,)`` holding the (batched) square matrix.
      tangents: one-element tuple ``(g,)`` holding the tangent of ``x``.

    Returns:
      ``((sign, logabsdet), (sign_dot, logabsdet_dot))``.
    """
    x, = primals
    g, = tangents
    sign, ans = slogdet(x)
    # tr(X^{-1} G) is the directional derivative of log(det X) along G.
    ans_dot = jnp.trace(solve(x, g), axis1=-1, axis2=-2)
    # Public dtype query instead of the private jnp._dtype helper.
    if jnp.issubdtype(jnp.result_type(x), jnp.complexfloating):
        # For complex X the real part drives log|det X|, and the imaginary
        # part rotates the sign. Use jnp.real rather than NumPy's np.real,
        # which only worked on JAX values via its `.real` attribute fallback.
        sign_dot = (ans_dot - jnp.real(ans_dot)) * sign
        ans_dot = jnp.real(ans_dot)
    else:
        sign_dot = jnp.zeros_like(sign)
    return (sign, ans), (sign_dot, ans_dot)
def _map_coordinates(input, coordinates, order, mode, cval):
    """Sample ``input`` at ``coordinates`` with order-0 or order-1 interpolation.

    Args:
      input: array to sample.
      coordinates: sequence of ``input.ndim`` coordinate arrays, one per axis.
      order: 0 (nearest-index helper) or 1 (linear helper); other values raise.
      mode: boundary mode; must be a key of ``_INDEX_FIXERS``.
      cval: fill value for out-of-range samples when ``mode == 'constant'``.

    Returns:
      The interpolated values cast to ``input.dtype``. When the input is an
      integer array, the values are first passed through
      ``_round_half_away_from_zero``.

    Raises:
      ValueError: if ``len(coordinates) != input.ndim``.
      NotImplementedError: for an unsupported ``mode`` or ``order``.
    """
    input = jnp.asarray(input)
    coordinates = [jnp.asarray(c) for c in coordinates]
    cval = jnp.asarray(cval, input.dtype)

    n_coords = len(coordinates)
    if n_coords != input.ndim:
        raise ValueError(
            'coordinates must be a sequence of length input.ndim, but '
            '{} != {}'.format(n_coords, input.ndim))

    index_fixer = _INDEX_FIXERS.get(mode)
    if index_fixer is None:
        raise NotImplementedError(
            'jax.scipy.ndimage.map_coordinates does not yet support mode {}. '
            'Currently supported modes are {}.'.format(mode, set(_INDEX_FIXERS)))

    if mode == 'constant':
        def is_valid(index, size):
            return (0 <= index) & (index < size)
    else:
        # A literal Python True lets the loop below skip masking entirely.
        def is_valid(index, size):
            return True

    if order == 0:
        interp_fun = _nearest_indices_and_weights
    elif order == 1:
        interp_fun = _linear_indices_and_weights
    else:
        raise NotImplementedError(
            'jax.scipy.ndimage.map_coordinates currently requires order<=1')

    # For each axis: the (fixed index, validity, weight) interpolation nodes.
    # Validity is judged on the raw index, before the boundary fix-up.
    per_axis_nodes = [
        [(index_fixer(index, size), is_valid(index, size), weight)
         for index, weight in interp_fun(coordinate)]
        for coordinate, size in zip(coordinates, input.shape)
    ]

    terms = []
    for combo in itertools.product(*per_axis_nodes):
        indices, validities, weights = zip(*combo)
        if all(v is True for v in validities):
            # Every node is statically in range: no masking required.
            sample = input[indices]
        else:
            mask = functools.reduce(operator.and_, validities)
            sample = jnp.where(mask, input[indices], cval)
        terms.append(_nonempty_prod(weights) * sample)

    result = _nonempty_sum(terms)
    if jnp.issubdtype(input.dtype, jnp.integer):
        result = _round_half_away_from_zero(result)
    return result.astype(input.dtype)
def triangular_solve(a, b, left_side=False, lower=False, transpose_a=False,
                     conjugate_a=False, unit_diagonal=False):
    """Solve ``op(a) @ x = b`` (left side) or ``x @ op(a) = b`` for triangular ``a``.

    If ``b`` has one dimension fewer than ``a``, it is treated as a batch of
    vectors, and the result is returned with that vector shape.
    """
    # Conjugation only matters for complex a; otherwise keep the flag as given.
    if conjugate_a:
        conjugate_a = jnp.issubdtype(lax.dtype(a), jnp.complexfloating)
    vector_rhs = jnp.ndim(b) == jnp.ndim(a) - 1
    if vector_rhs:
        insert_axis = -1 if left_side else -2
        b = jnp.expand_dims(b, insert_axis)
    params = dict(left_side=left_side, lower=lower, transpose_a=transpose_a,
                  conjugate_a=conjugate_a, unit_diagonal=unit_diagonal)
    out = triangular_solve_p.bind(a, b, **params)
    if vector_rhs:
        if left_side:
            return out[..., 0]
        return out[..., 0, :]
    return out
def _slogdet_qr(a): # Implementation of slogdet using QR decomposition. One reason we might prefer # QR decomposition is that it is more amenable to a fast batched # implementation on TPU because of the lack of row pivoting. if jnp.issubdtype(lax.dtype(a), jnp.complexfloating): raise NotImplementedError("slogdet method='qr' not implemented for complex " "inputs") n = a.shape[-1] a, taus = lax_linalg.geqrf(a) # The determinant of a triangular matrix is the product of its diagonal # elements. We are working in log space, so we compute the magnitude as the # the trace of the log-absolute values, and we compute the sign separately. log_abs_det = jnp.trace(jnp.log(jnp.abs(a)), axis1=-2, axis2=-1) sign_diag = jnp.prod(jnp.sign(jnp.diagonal(a, axis1=-2, axis2=-1)), axis=-1) # The determinant of a Householder reflector is -1. So whenever we actually # made a reflection (tau != 0), multiply the result by -1. sign_taus = jnp.prod(jnp.where(taus[..., :(n-1)] != 0, -1, 1), axis=-1).astype(sign_diag.dtype) return sign_diag * sign_taus, log_abs_det
def polygamma(n, x):
    """Polygamma function of integer order ``n`` evaluated at ``x``.

    ``n`` and ``x`` are promoted to a common inexact dtype and broadcast
    against each other before evaluation.

    Raises:
      AssertionError: if ``n`` does not have an integer dtype.
    """
    n_dtype = jnp.result_type(n)
    # An explicit check, unlike a bare `assert`, stays active under `python -O`.
    # AssertionError is kept for compatibility with the previous `assert`.
    if not jnp.issubdtype(n_dtype, jnp.integer):
        raise AssertionError(
            f"polygamma: n must have an integer dtype, got {n_dtype}")
    n, x = _promote_args_inexact("polygamma", n, x)
    shape = lax.broadcast_shapes(n.shape, x.shape)
    return _polygamma(jnp.broadcast_to(n, shape), jnp.broadcast_to(x, shape))
def eigh_tridiagonal(d, e, *, eigvals_only=False, select='a',
                     select_range=None, tol=None):
    """Compute eigenvalues of a symmetric/Hermitian tridiagonal matrix T.

    All requested eigenvalues are found in parallel by bisection on
    Sturm-sequence counts. The search starts from a slightly widened
    Gershgorin interval.

    Args:
      d: main diagonal of T (length ``n``).
      e: off-diagonal of T (length ``n - 1``); must have the same dtype as
        ``d``.
      eigvals_only: must be ``True``; eigenvectors are not implemented.
      select: ``'a'`` for all eigenvalues, or ``'i'`` for the eigenvalues with
        indices ``select_range[0]`` through ``select_range[1]`` inclusive.
        ``'v'`` is not implemented.
      select_range: ``(lo, hi)`` index pair, required when ``select='i'``.
      tol: optional absolute tolerance; the larger of ``tol`` and
        ``eps * ||T||`` (estimated) is used.

    Returns:
      The selected eigenvalues in ascending order (for ``n <= 1``, simply the
      real part of ``d``).

    Raises:
      NotImplementedError: if ``eigvals_only`` is false or ``select='v'``.
      TypeError: if the dtypes of ``d`` and ``e`` differ or are unsupported.
      ValueError: for an invalid ``select`` or ``select_range``.
    """
    if not eigvals_only:
        raise NotImplementedError(
            "Calculation of eigenvectors is not implemented")

    def _sturm(alpha, beta_sq, pivmin, alpha0_perturbation, x):
        """Run the Sturm sequence recurrence and return the per-``x`` counts.

        The bisection below treats ``count`` as the number of eigenvalues of T
        smaller than the corresponding entry of ``x``.
        """
        n = alpha.shape[0]
        zeros = jnp.zeros(x.shape, dtype=jnp.int32)
        ones = jnp.ones(x.shape, dtype=jnp.int32)

        # The first step in the Sturm sequence recurrence
        # requires special care if x is equal to alpha[0].
        def sturm_step0():
            q = alpha[0] - x
            count = jnp.where(q < 0, ones, zeros)
            q = jnp.where(alpha[0] == x, alpha0_perturbation, q)
            return q, count

        # Subsequent steps all take this form:
        def sturm_step(i, q, count):
            q = alpha[i] - beta_sq[i - 1] / q - x
            count = jnp.where(q <= pivmin, count + 1, count)
            q = jnp.where(q <= pivmin, jnp.minimum(q, -pivmin), q)
            return q, count

        def make_unrolled_steps(num_steps):
            """Return a loop body that applies `num_steps` consecutive steps.

            The step count is bound explicitly here instead of being read from
            a mutated enclosing variable at call time.
            """
            def unrolled_steps(args):
                start, q, count = args
                for j in range(num_steps):
                    q, count = sturm_step(start + j, q, count)
                return start + num_steps, q, count
            return unrolled_steps

        # The first step initializes q and count.
        q, count = sturm_step0()

        # Peel off ((n-1) % blocksize) steps from the main loop, so we can run
        # the bulk of the iterations unrolled by a factor of blocksize.
        blocksize = 16
        peel = (n - 1) % blocksize
        i, q, count = make_unrolled_steps(peel)((1, q, count))

        # Run the remaining steps of the Sturm sequence using a partially
        # unrolled while loop.
        def cond(iqc):
            i, q, count = iqc
            return jnp.less(i, n)
        _, _, count = lax.while_loop(
            cond, make_unrolled_steps(blocksize), (i, q, count))
        return count

    alpha = jnp.asarray(d)
    beta = jnp.asarray(e)
    supported_dtypes = (jnp.float32, jnp.float64, jnp.complex64, jnp.complex128)
    if alpha.dtype != beta.dtype:
        raise TypeError(
            "diagonal and off-diagonal values must have same dtype, "
            f"got {alpha.dtype} and {beta.dtype}")
    if alpha.dtype not in supported_dtypes or beta.dtype not in supported_dtypes:
        # The message now lists every dtype in supported_dtypes.
        raise TypeError(
            "Only float32, float64, complex64 and complex128 inputs are "
            "supported as inputs to jax.scipy.linalg.eigh_tridiagonal, got "
            f"{alpha.dtype} and {beta.dtype}")

    n = alpha.shape[0]
    if n <= 1:
        return jnp.real(alpha)

    if jnp.issubdtype(alpha.dtype, jnp.complexfloating):
        alpha = jnp.real(alpha)
        beta_sq = jnp.real(beta * jnp.conj(beta))
        beta_abs = jnp.sqrt(beta_sq)
    else:
        beta_abs = jnp.abs(beta)
        beta_sq = jnp.square(beta)

    # Estimate the largest and smallest eigenvalues of T using the Gershgorin
    # circle theorem.
    off_diag_abs_row_sum = jnp.concatenate(
        [beta_abs[:1], beta_abs[:-1] + beta_abs[1:], beta_abs[-1:]], axis=0)
    lambda_est_max = jnp.amax(alpha + off_diag_abs_row_sum)
    lambda_est_min = jnp.amin(alpha - off_diag_abs_row_sum)
    # Upper bound on 2-norm of T.
    t_norm = jnp.maximum(jnp.abs(lambda_est_min), jnp.abs(lambda_est_max))

    # Compute the smallest allowed pivot in the Sturm sequence to avoid
    # overflow.
    finfo = np.finfo(alpha.dtype)
    one = np.ones([], dtype=alpha.dtype)
    safemin = np.maximum(one / finfo.max, (one + finfo.eps) * finfo.tiny)
    pivmin = safemin * jnp.maximum(1, jnp.amax(beta_sq))
    alpha0_perturbation = jnp.square(finfo.eps * beta_abs[0])
    abs_tol = finfo.eps * t_norm
    if tol is not None:
        abs_tol = jnp.maximum(tol, abs_tol)

    # In the worst case, when the absolute tolerance is eps*lambda_est_max and
    # lambda_est_max = -lambda_est_min, we have to take as many bisection steps
    # as there are bits in the mantissa plus 1.
    max_it = finfo.nmant + 1

    # Determine the indices of the desired eigenvalues, based on select and
    # select_range.
    if select == 'a':
        target_counts = jnp.arange(n, dtype=jnp.int32)
    elif select == 'i':
        if select_range is None:
            raise ValueError("select_range must be provided when select='i'.")
        if select_range[0] > select_range[1]:
            raise ValueError('Got empty index range in select_range.')
        target_counts = jnp.arange(select_range[0], select_range[1] + 1,
                                   dtype=jnp.int32)
    elif select == 'v':
        # TODO(phawkins): requires dynamic shape support.
        raise NotImplementedError("eigh_tridiagonal(..., select='v') is not "
                                  "implemented")
    else:
        raise ValueError("'select' must have a value in {'a', 'i', 'v'}.")

    # Run binary search for all desired eigenvalues in parallel, starting from
    # an interval slightly wider than the estimated
    # [lambda_est_min, lambda_est_max].
    fudge = 2.1  # Factor by which the Gershgorin interval is widened.
    norm_slack = jnp.array(n, alpha.dtype) * fudge * finfo.eps * t_norm
    lower = lambda_est_min - norm_slack - 2 * fudge * pivmin
    upper = lambda_est_max + norm_slack + fudge * pivmin

    # Pre-broadcast the scalars used in the Sturm sequence for improved
    # performance.
    target_shape = jnp.shape(target_counts)
    lower = jnp.broadcast_to(lower, shape=target_shape)
    upper = jnp.broadcast_to(upper, shape=target_shape)
    mid = 0.5 * (upper + lower)
    pivmin = jnp.broadcast_to(pivmin, target_shape)
    alpha0_perturbation = jnp.broadcast_to(alpha0_perturbation, target_shape)

    # Start parallel binary searches.
    def cond(args):
        i, lower, _, upper = args
        return jnp.logical_and(
            jnp.less(i, max_it),
            jnp.less(abs_tol, jnp.amax(upper - lower)))

    def body(args):
        i, lower, mid, upper = args
        counts = _sturm(alpha, beta_sq, pivmin, alpha0_perturbation, mid)
        # Fewer than target+1 eigenvalues below mid: the target lies above it.
        lower = jnp.where(counts <= target_counts, mid, lower)
        upper = jnp.where(counts > target_counts, mid, upper)
        mid = 0.5 * (lower + upper)
        return i + 1, lower, mid, upper

    _, _, mid, _ = lax.while_loop(cond, body, (0, lower, mid, upper))
    return mid
def _to_inexact_type(type):
    """Return ``type`` if it is inexact (floating/complex), else ``jnp.float_``."""
    if jnp.issubdtype(type, jnp.inexact):
        return type
    return jnp.float_
def _round_half_away_from_zero(a):
    """Round ``a`` with ``lax.round``; integer arrays are returned unchanged.

    ``lax.round``'s default rounding method sends ties away from zero, which
    matches this helper's name.
    """
    if jnp.issubdtype(a.dtype, jnp.integer):
        return a
    return lax.round(a)