def eigh_jvp_rule(primals, tangents, lower):
  """JVP rule for ``eigh`` in the simplest case of distinct eigenvalues.

  This is classic nondegenerate perturbation theory, but also see
  https://people.maths.ox.ac.uk/gilesm/files/NA-08-01.pdf

  The general solution treating the case of degenerate eigenvalues is
  considerably more complicated. Ambitious readers may refer to degenerate
  perturbation theory in physics, or to
  https://www.win.tue.nl/analysis/reports/rana06-33.pdf and
  https://people.orie.cornell.edu/aslewis/publications/99-clarke.pdf

  Args:
    primals: one-element sequence holding the input matrix ``a``.
    tangents: one-element sequence holding the tangent ``a_dot`` of ``a``.
    lower: forwarded unchanged to ``eigh_p.bind``.

  Returns:
    A pair ``((v, w_real), (dv, dw))``: the outputs of ``eigh_p.bind`` followed
    by their tangents.
  """
  a, = primals
  a_dot, = tangents

  v, w_real = eigh_p.bind(symmetrize(a), lower=lower)

  # For complex numbers we need eigenvalues to be full dtype of v, a.
  w = w_real.astype(a.dtype)
  eye_n = jnp.eye(a.shape[-1], dtype=a.dtype)

  # Carefully build the reciprocal delta-eigenvalue matrix, avoiding NaNs.
  # The eigenvalue gap is formed *before* the identity is added: evaluating
  # (1 + w_i) - w_i instead is not exactly 1 in floating point (it collapses to
  # 0 once |w_i| is large relative to the dtype's precision), which would leave
  # a nonzero or infinite diagonal in Fmat. With the gap formed first the
  # diagonal is exactly 1 before the reciprocal and exactly 0 afterwards; the
  # off-diagonal entries are unchanged.
  gaps = w[..., jnp.newaxis, :] - w[..., jnp.newaxis]
  Fmat = jnp.reciprocal(eye_n + gaps) - eye_n

  # eigh impl doesn't support batch dims, but future-proof the grad.
  dot = partial(lax.dot if a.ndim == 2 else lax.batch_matmul,
                precision=lax.Precision.HIGHEST)
  vdag_adot_v = dot(dot(_H(v), a_dot), v)
  dv = dot(v, jnp.multiply(Fmat, vdag_adot_v))
  dw = jnp.real(jnp.diagonal(vdag_adot_v, axis1=-2, axis2=-1))
  return (v, w_real), (dv, dw)
def svd_jvp_rule(primals, tangents, full_matrices, compute_uv):
  """JVP rule for the singular value decomposition (reduced form only).

  Real and complex inputs are both handled: every symmetrization uses the
  conjugate transpose ``_H``, and the anti-Hermitian diagonal part of
  ``U^H dA V`` (which is identically zero for real inputs) is accounted for in
  ``dU``.

  Args:
    primals: one-element sequence holding the matrix ``A``.
    tangents: one-element sequence holding the tangent ``dA`` of ``A``.
    full_matrices: if true together with ``compute_uv`` the rule is not
      implemented and raises.
    compute_uv: whether tangents for ``U`` and ``Vt`` are requested in addition
      to the tangent of the singular values.

  Returns:
    ``((s,), (ds,))`` when ``compute_uv`` is false, otherwise
    ``((s, U, Vt), (ds, dU, dVt))``.

  Raises:
    NotImplementedError: if both ``compute_uv`` and ``full_matrices`` are set.
  """
  A, = primals
  dA, = tangents
  if compute_uv and full_matrices:
    # Checked before the decomposition so no work is wasted on a call that is
    # going to fail anyway.
    # TODO: implement full matrices case, documented here:
    # https://people.maths.ox.ac.uk/gilesm/files/NA-08-01.pdf
    raise NotImplementedError(
      "Singular value decomposition JVP not implemented for full matrices")

  s, U, Vt = svd_p.bind(A, full_matrices=False, compute_uv=True)

  k = s.shape[-1]
  Ut, V = _H(U), _H(Vt)
  s_dim = s[..., None, :]
  dS = jnp.matmul(jnp.matmul(Ut, dA), V)
  ds = jnp.real(jnp.diagonal(dS, 0, -2, -1))
  if not compute_uv:
    return (s,), (ds,)

  eye_k = jnp.eye(k, dtype=A.dtype)
  # F[i, j] = 1 / (s_j^2 - s_i^2) off the diagonal and 0 on it. The identity is
  # added before the reciprocal to avoid dividing by zero on the diagonal and
  # subtracted again afterwards.
  F = 1 / (jnp.square(s_dim) - jnp.square(_T(s_dim)) + eye_k) - eye_k
  dSS = s_dim * dS  # dS.dot(jnp.diag(s))
  SdS = _T(s_dim) * dS  # jnp.diag(s).dot(dS)

  # Reciprocal of the singular values, with 0 substituted where s == 0.
  s_zeros = jnp.ones((), dtype=A.dtype) * (s == 0.)
  s_inv = 1 / (s + s_zeros) - s_zeros
  # Batched diag(s_inv); multiplying elementwise by it keeps only the diagonal
  # of the anti-Hermitian part of dS.
  s_inv_mat = eye_k * s_inv[..., None, :]
  dUdV_diag = .5 * (dS - _H(dS)) * s_inv_mat

  dU = jnp.matmul(U, F * (dSS + _H(dSS)) + dUdV_diag)
  dV = jnp.matmul(V, F * (SdS + _H(SdS)))

  m, n = A.shape[-2:]
  if m > n:
    # Component of dU outside the column space of U (tall matrices only).
    dU = dU + jnp.matmul(jnp.eye(m, dtype=A.dtype) - jnp.matmul(U, Ut),
                         jnp.matmul(dA, V)) / s_dim
  if n > m:
    # Component of dV outside the column space of V (wide matrices only).
    dV = dV + jnp.matmul(jnp.eye(n, dtype=A.dtype) - jnp.matmul(V, Vt),
                         jnp.matmul(_H(dA), U)) / s_dim
  # Vt = V^H, so its tangent is the conjugate transpose of dV.
  return (s, U, Vt), (ds, dU, _H(dV))
def _slogdet_lu(a):
  """Sign and log-absolute-value of the determinant from an LU factorization.

  Args:
    a: array whose trailing two dimensions are factorized by
      ``lax_linalg.lu``.

  Returns:
    ``(sign, logabsdet)``. Wherever the diagonal of the LU factor contains an
    exact zero, ``sign`` is 0 and ``logabsdet`` is ``-inf``.
  """
  dtype = lax.dtype(a)
  zero = jnp.array(0, dtype=dtype)
  factors, pivots, _ = lax_linalg.lu(a)
  u_diag = jnp.diagonal(factors, axis1=-2, axis2=-1)
  singular = jnp.any(u_diag == zero, axis=-1)

  # Each pivot entry that differs from its own index counts as one row swap.
  identity_perm = lax.expand_dims(jnp.arange(a.shape[-1]),
                                  range(pivots.ndim - 1))
  swaps = jnp.count_nonzero(pivots != identity_perm, axis=-1)

  if not jnp.iscomplexobj(a):
    # Real case: every negative diagonal entry flips the sign once more, so it
    # is folded into the swap count.
    phase = jnp.array(1, dtype=dtype)
    swaps = swaps + jnp.count_nonzero(u_diag < 0, axis=-1)
  else:
    # Complex case: the phase is the product of the unit-modulus factors.
    phase = jnp.prod(u_diag / jnp.abs(u_diag), axis=-1)

  swap_sign = jnp.array(-2 * (swaps % 2) + 1, dtype=dtype)
  sign = jnp.where(singular, zero, phase * swap_sign)

  log_magnitude = jnp.sum(jnp.log(jnp.abs(u_diag)), axis=-1)
  logdet = jnp.where(singular, jnp.array(-jnp.inf, dtype=dtype), log_magnitude)
  return sign, jnp.real(logdet)
def _slogdet_qr(a):
  """Sign and log-absolute-value of the determinant via QR decomposition.

  One reason we might prefer QR decomposition is that it is more amenable to a
  fast batched implementation on TPU because of the lack of row pivoting.

  Args:
    a: real-valued array whose trailing two dimensions are factorized by
      ``lax_linalg.geqrf``.

  Returns:
    ``(sign, logabsdet)``.

  Raises:
    NotImplementedError: if ``a`` has a complex dtype.
  """
  if jnp.issubdtype(lax.dtype(a), jnp.complexfloating):
    raise NotImplementedError("slogdet method='qr' not implemented for complex "
                              "inputs")
  n = a.shape[-1]
  a, taus = lax_linalg.geqrf(a)

  # The determinant of a triangular matrix is the product of its diagonal
  # elements. Only that diagonal is needed, so it is extracted once and shared
  # by the magnitude and sign computations instead of taking the log of every
  # entry of the packed factorization.
  r_diag = jnp.diagonal(a, axis1=-2, axis2=-1)

  # We are working in log space, so the magnitude is the sum of the
  # log-absolute values of the diagonal; the sign is computed separately.
  log_abs_det = jnp.sum(jnp.log(jnp.abs(r_diag)), axis=-1)
  sign_diag = jnp.prod(jnp.sign(r_diag), axis=-1)

  # The determinant of a Householder reflector is -1. So whenever we actually
  # made a reflection (tau != 0), multiply the result by -1.
  sign_taus = jnp.prod(jnp.where(taus[..., :(n-1)] != 0, -1, 1),
                       axis=-1).astype(sign_diag.dtype)
  return sign_diag * sign_taus, log_abs_det
def slogdet(a):
  """Compute the sign and the log of the absolute value of a determinant.

  The result is derived from the diagonal and pivots of an LU factorization.

  Args:
    a: array-like of shape ``[..., n, n]``.

  Returns:
    ``(sign, logabsdet)``. Wherever the diagonal of the LU factor contains an
    exact zero, ``sign`` is 0 and ``logabsdet`` is ``-inf``.

  Raises:
    ValueError: if ``a`` has fewer than two dimensions or its trailing two
      dimensions are not equal.
  """
  a = _promote_arg_dtypes(jnp.asarray(a))
  dtype = lax.dtype(a)
  a_shape = jnp.shape(a)
  is_square_matrix = len(a_shape) >= 2 and a_shape[-1] == a_shape[-2]
  if not is_square_matrix:
    raise ValueError(
        "Argument to slogdet() must have shape [..., n, n], got {}".format(
            a_shape))

  zero = jnp.array(0, dtype=dtype)
  factors, pivots, _ = lax_linalg.lu(a)
  u_diag = jnp.diagonal(factors, axis1=-2, axis2=-1)
  singular = jnp.any(u_diag == zero, axis=-1)

  # Each pivot entry that differs from its own index counts as one row swap.
  swaps = jnp.count_nonzero(pivots != jnp.arange(a_shape[-1]), axis=-1)

  if not jnp.iscomplexobj(a):
    # Real case: every negative diagonal entry flips the sign once more, so it
    # is folded into the swap count.
    phase = jnp.array(1, dtype=dtype)
    swaps = swaps + jnp.count_nonzero(u_diag < 0, axis=-1)
  else:
    # Complex case: the phase is the product of the unit-modulus factors.
    phase = jnp.prod(u_diag / jnp.abs(u_diag), axis=-1)

  swap_sign = jnp.array(-2 * (swaps % 2) + 1, dtype=dtype)
  sign = jnp.where(singular, zero, phase * swap_sign)

  log_magnitude = jnp.sum(jnp.log(jnp.abs(u_diag)), axis=-1)
  logdet = jnp.where(singular, jnp.array(-jnp.inf, dtype=dtype), log_magnitude)
  return sign, jnp.real(logdet)
def svd_jvp_rule(primals, tangents, full_matrices, compute_uv):
  """JVP rule for the singular value decomposition (reduced form only).

  Args:
    primals: one-element sequence holding the matrix ``A``.
    tangents: one-element sequence holding the tangent ``dA`` of ``A``.
    full_matrices: if true together with ``compute_uv`` the rule is not
      implemented and raises.
    compute_uv: whether tangents for ``U`` and ``Vt`` are requested in addition
      to the tangent of the singular values.

  Returns:
    ``((s,), (ds,))`` when ``compute_uv`` is false, otherwise
    ``((s, U, Vt), (ds, dU, dVt))``.

  Raises:
    NotImplementedError: if both ``compute_uv`` and ``full_matrices`` are set.
  """
  (A,) = primals
  (dA,) = tangents
  s, U, Vt = svd_p.bind(A, full_matrices=False, compute_uv=True)

  if compute_uv and full_matrices:
    # TODO: implement full matrices case, documented here:
    # https://people.maths.ox.ac.uk/gilesm/files/NA-08-01.pdf
    raise NotImplementedError(
      "Singular value decomposition JVP not implemented for full matrices")

  Ut = _H(U)
  V = _H(Vt)
  s_row = s[..., None, :]
  s_col = _T(s_row)

  # Tangent expressed in the basis of the singular vectors; the real part of
  # its diagonal is the tangent of the singular values.
  dS = jnp.matmul(jnp.matmul(Ut, dA), V)
  ds = jnp.real(jnp.diagonal(dS, 0, -2, -1))
  if not compute_uv:
    return (s,), (ds,)

  # F[i, j] = 1 / (s_j^2 - s_i^2) off the diagonal and 0 on it. The identity is
  # added before the reciprocal to avoid dividing by zero on the diagonal and
  # subtracted again afterwards.
  sq_gaps = jnp.square(s_row) - jnp.square(s_col)
  diag_mask = jnp.eye(s.shape[-1], dtype=A.dtype)
  F = 1 / (sq_gaps + diag_mask) - diag_mask

  dS_times_s = s_row * dS  # dS.dot(jnp.diag(s))
  s_times_dS = s_col * dS  # jnp.diag(s).dot(dS)

  # Reciprocal of the singular values, with 0 substituted where s == 0.
  zero_mask = jnp.ones((), dtype=A.dtype) * (s == 0.)
  s_recip = 1 / (s + zero_mask) - zero_mask
  batched_diag = jnp.vectorize(jnp.diag, signature='(k)->(k,k)')
  s_recip_mat = batched_diag(s_recip)
  # Diagonal of the anti-Hermitian part of dS, scaled by 1 / s.
  diag_correction = .5 * (dS - _H(dS)) * s_recip_mat

  dU = jnp.matmul(U, F * (dS_times_s + _H(dS_times_s)) + diag_correction)
  dV = jnp.matmul(V, F * (s_times_dS + _H(s_times_dS)))

  m, n = A.shape[-2:]
  if m > n:
    # Component of dU outside the column space of U (tall matrices only).
    complement_U = jnp.eye(m, dtype=A.dtype) - jnp.matmul(U, Ut)
    dU = dU + jnp.matmul(complement_U, jnp.matmul(dA, V)) / s_row
  if n > m:
    # Component of dV outside the column space of V (wide matrices only).
    complement_V = jnp.eye(n, dtype=A.dtype) - jnp.matmul(V, Vt)
    dV = dV + jnp.matmul(complement_V, jnp.matmul(_H(dA), U)) / s_row

  return (s, U, Vt), (ds, dU, _H(dV))
def _cofactor_solve(a, b):
  """Equivalent to det(a)*solve(a, b) for nonsingular mat.

  Intermediate function used for jvp and vjp of det.
  This function borrows heavily from jax.numpy.linalg.solve and
  jax.numpy.linalg.slogdet to compute the gradient of the determinant
  in a way that is well defined even for low rank matrices.

  This function handles two different cases:
  * rank(a) == n or n-1
  * rank(a) < n-1

  For rank n-1 matrices, the gradient of the determinant is a rank 1 matrix.
  Rather than computing det(a)*solve(a, b), which would return NaN, we work
  directly with the LU decomposition. If a = p @ l @ u, then
    det(a)*solve(a, b)
      = prod(diag(u)) * u^-1 @ l^-1 @ p^-1 b
      = prod(diag(u)) * triangular_solve(u, solve(p @ l, b))
  If a is rank n-1, then the lower right corner of u will be zero and the
  triangular_solve will fail.
  Let x = solve(p @ l, b) and y = det(a)*solve(a, b). Then
    y_{n} = x_{n} / u_{nn} * prod_{i=1...n}(u_{ii})
          = x_{n} * prod_{i=1...n-1}(u_{ii})
  So by replacing the lower-right corner of u with prod_{i=1...n-1}(u_{ii})^-1
  we can avoid the triangular_solve failing.
  To correctly compute the rest of y_{i} for i != n, we simply multiply
  x_{i} by det(a) for all i != n, which will be zero if rank(a) = n-1.

  For the second case, a check is done on the matrix to see if `solve`
  returns NaN or Inf, and gives a matrix of zeros as a result, as the
  gradient of the determinant of a matrix with rank less than n-1 is 0.
  This will still return the correct value for rank n-1 matrices, as the check
  is applied *after* the lower right corner of u has been updated.

  Args:
    a: A square matrix or batch of matrices, possibly singular.
    b: A matrix, or batch of matrices of the same dimension as a.

  Returns:
    det(a) and cofactor(a)^T*b, aka adjugate(a)*b

  Raises:
    ValueError: if ``a`` is not of shape [..., m, m] or the trailing two
      dimensions of ``b`` do not match those of ``a``.
  """
  a = _promote_arg_dtypes(jnp.asarray(a))
  b = _promote_arg_dtypes(jnp.asarray(b))
  a_shape = jnp.shape(a)
  b_shape = jnp.shape(b)
  n = a_shape[-1] if a_shape else None

  shapes_invalid = (len(a_shape) < 2
                    or a_shape[-1] != a_shape[-2]
                    or b_shape[-2:] != a_shape[-2:])
  if shapes_invalid:
    template = ("The arguments to _cofactor_solve must have shapes "
                "a=[..., m, m] and b=[..., m, m]; got a={} and b={}")
    raise ValueError(template.format(a_shape, b_shape))

  # A 1x1 matrix is its own determinant and its adjugate is the identity.
  if n == 1:
    return a[..., 0, 0], b

  # lu contains u in the upper triangular matrix and l in the strict lower
  # triangular matrix.
  # The diagonal of l is set to ones without loss of generality.
  lu, pivots, permutation = lax_linalg.lu(a)
  dtype = lax.dtype(a)
  batch_dims = lax.broadcast_shapes(lu.shape[:-2], b.shape[:-2])
  x = jnp.broadcast_to(b, batch_dims + b.shape[-2:])
  lu = jnp.broadcast_to(lu, batch_dims + lu.shape[-2:])

  # Compute (partial) determinant, ignoring last diagonal of LU
  u_diag = jnp.diagonal(lu, axis1=-2, axis2=-1)
  swaps = jnp.count_nonzero(pivots != jnp.arange(n), axis=-1)
  sign = jnp.asarray(-2 * (swaps % 2) + 1, dtype=dtype)
  # partial_det[:, -1] contains the full determinant and
  # partial_det[:, -2] contains det(u) / u_{nn}.
  partial_det = jnp.cumprod(u_diag, axis=-1) * sign[..., None]
  full_det = partial_det[..., -1]
  lu = lu.at[..., -1, -1].set(1.0 / partial_det[..., -2])

  permutation = jnp.broadcast_to(permutation, batch_dims + (n,))
  iotas = jnp.ix_(*(lax.iota(jnp.int32, extent)
                    for extent in batch_dims + (1,)))

  # filter out any matrices that are not full rank
  probe = jnp.ones(x.shape[:-1], x.dtype)
  probe = lax_linalg.triangular_solve(lu, probe, left_side=True, lower=False)
  not_finite = jnp.logical_or(jnp.isnan(probe), jnp.isinf(probe))
  low_rank = jnp.any(not_finite, axis=-1)
  low_rank = jnp.broadcast_to(low_rank[..., None, None], x.shape)
  x = jnp.where(low_rank, jnp.zeros_like(x), x)  # first filter

  # Apply the row permutation, then solve against the unit lower factor.
  x = x[iotas[:-1] + (permutation, slice(None))]
  x = lax_linalg.triangular_solve(lu, x, left_side=True, lower=True,
                                  unit_diagonal=True)
  # Scale every row except the last by det(a); the last row is already
  # accounted for by the patched lower-right corner of u.
  scaled_rows = x[..., :-1, :] * full_det[..., None, None]
  last_row = x[..., -1:, :]
  x = jnp.concatenate((scaled_rows, last_row), axis=-2)
  x = lax_linalg.triangular_solve(lu, x, left_side=True, lower=False)
  x = jnp.where(low_rank, jnp.zeros_like(x), x)  # second filter

  return full_det, x