def phif(s):
    pmagsq = np.sum(abarsq / np.square(lambar + s), axis=-1)
    pmag = np.sqrt(pmagsq)
    phipartial = np.reciprocal(pmag)
    singular = np.any(np.equal(-s, lambar), axis=-1)
    phipartial = np.where(singular, 0., phipartial)
    phi = phipartial - np.reciprocal(trust_radius)
    return phi
def minimum_volume_enclosing_ellipsoid(points, tol, init_u=None, return_u=False):
    """
    Performs the algorithm of

        MINIMUM VOLUME ENCLOSING ELLIPSOIDS
        NIMA MOSHTAGH

    pseudo-code here:
    https://stackoverflow.com/questions/1768197/bounding-ellipse

    Args:
        points: [N, D]
    """
    N, D = points.shape
    Q = jnp.concatenate([points, jnp.ones([N, 1])], axis=1)  # N, D+1

    def body(state):
        (count, err, u) = state
        V = Q.T @ jnp.diag(u) @ Q  # D+1, D+1
        # g[i] = Q[i,j].V^-1_jk.Q[i,k], i.e. jnp.diag(Q @ jnp.linalg.solve(V, Q.T))
        g = vmap(lambda q: q @ jnp.linalg.solve(V, q))(Q)
        j = jnp.argmax(g)
        g_max = g[j]
        step_size = (g_max - D - 1) / ((D + 1) * (g_max - 1))
        # u <- (1 - step_size) * u, except at j where u_j gains step_size * (1 - u_j)
        new_u = jnp.where(jnp.arange(N) == j,
                          u + step_size * (1. - u),
                          u * (1. - step_size))
        new_err = jnp.linalg.norm(u - new_u)
        return (count + 1, new_err, new_u)

    if init_u is None:
        init_u = jnp.ones(N) / N
    (count, err, u) = while_loop(
        lambda state: state[1] > tol * jnp.linalg.norm(init_u),
        body,
        (0, jnp.inf, init_u))
    U = jnp.diag(u)
    PU = points.T @ u  # D
    A = jnp.reciprocal(D) * jnp.linalg.pinv(
        points.T @ U @ points - PU[:, None] @ PU[None, :])
    c = PU
    W, Q, Vh = jnp.linalg.svd(A)
    radii = jnp.reciprocal(jnp.sqrt(Q))
    rotation = Vh.conj().T
    if return_u:
        return c, radii, rotation, u
    return c, radii, rotation
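# A minimal usage sketch (hypothetical; assumes the module's jax imports such as
# vmap and while_loop are in scope): wrap a 2-D point cloud and check that every
# point satisfies the ellipsoid inequality (x - c)^T A (x - c) <= 1 up to the
# tolerance of the iteration. This mirrors the check in debug_mvee below.
import jax.numpy as jnp
from jax import random

points = random.uniform(random.PRNGKey(0), (50, 2))
c, radii, rotation = minimum_volume_enclosing_ellipsoid(points, 1e-3)
# Map points into the unit-ball frame of the ellipsoid and measure their norms.
dists = jnp.linalg.norm((rotation.T @ (points.T - c[:, None])) / radii[:, None], axis=0)
print(float(dists.max()))  # expected to be approximately 1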
def init_fn(z, rng, step_size=1.0, inverse_mass_matrix=None, mass_matrix_size=None):
    """
    :param z: Initial position of the integrator.
    :param jax.random.PRNGKey rng: Random key to be used as the source of randomness.
    :param float step_size: Initial step size.
    :param inverse_mass_matrix: Inverse of the initial mass matrix. If ``None``,
        the inverse mass matrix will be an identity matrix whose size is decided
        by the argument `mass_matrix_size`.
    :param int mass_matrix_size: Size of the mass matrix.
    :return: initial state of the adapt scheme.
    """
    rng, rng_ss = random.split(rng)
    if inverse_mass_matrix is None:
        assert mass_matrix_size is not None
        if dense_mass:
            inverse_mass_matrix = np.identity(mass_matrix_size)
        else:
            inverse_mass_matrix = np.ones(mass_matrix_size)
        mass_matrix_sqrt = inverse_mass_matrix
    else:
        if dense_mass:
            mass_matrix_sqrt = cholesky_inverse(inverse_mass_matrix)
        else:
            mass_matrix_sqrt = np.sqrt(np.reciprocal(inverse_mass_matrix))
    if adapt_step_size:
        step_size = find_reasonable_step_size(inverse_mass_matrix, z, rng_ss, step_size)
    ss_state = ss_init(np.log(10 * step_size))
    mm_state = mm_init(inverse_mass_matrix.shape[-1])
    window_idx = 0
    return AdaptState(step_size, inverse_mass_matrix, mass_matrix_sqrt,
                      ss_state, mm_state, window_idx, rng)
def quantized_softmax(a):
    # We compute softmax as exp(x - max(x)) / sum_i(exp(x_i - max(x))),
    # quantizing intermediate values. Note this differs from the log-domain
    # implementation of softmax used above.
    quant_hparams = softmax_hparams.quant_hparams
    fp_quant_config = QuantOps.FloatQuant(is_scaled=False, fp_spec=quant_hparams.prec)
    quant_ops = QuantOps.create_symmetric_fp(fp_quant=fp_quant_config, bounds=None)

    a = quant_ops.to_quantized(a, dtype=dtype)
    # Note that the max of a quantized vector is necessarily also quantized to
    # the same precision since the max of a vector must be an existing element
    # of the vector, so we don't need to explicitly insert a quantization
    # operator to the output of the max reduction.
    a_max = jnp.max(a, axis=norm_dims, keepdims=True)
    a_minus_max = quant_ops.to_quantized(a - a_max, dtype=dtype)
    a_exp = quant_ops.to_quantized(jnp.exp(a_minus_max), dtype=dtype)

    sum_exp_quantized_reduction = quantization.quantized_sum(
        a_exp, axis=norm_dims, keepdims=True, prec=quant_hparams.reduction_prec)
    sum_exp = quant_ops.to_quantized(sum_exp_quantized_reduction, dtype=dtype)

    inv_sum_exp = quant_ops.to_quantized(jnp.reciprocal(sum_exp), dtype=dtype)
    a_softmax = quant_ops.to_quantized(a_exp * inv_sum_exp, dtype=dtype)

    return a_softmax.astype(dtype)
def test_welford_covariance(jitted, diagonal, regularize):
    with optional(jitted, disable_jit()), \
            optional(jitted, control_flow_prims_disabled()):
        np.random.seed(0)
        loc = np.random.randn(3)
        a = np.random.randn(3, 3)
        target_cov = np.matmul(a, a.T)
        x = np.random.multivariate_normal(loc, target_cov, size=(2000,))
        x = device_put(x)

        @jit
        def get_cov(x):
            wc_init, wc_update, wc_final = welford_covariance(diagonal=diagonal)
            wc_state = wc_init(3)
            wc_state = fori_loop(0, 2000, lambda i, val: wc_update(x[i], val), wc_state)
            cov, cov_inv_sqrt = wc_final(wc_state, regularize=regularize)
            return cov, cov_inv_sqrt

        cov, cov_inv_sqrt = get_cov(x)

        if diagonal:
            diag_cov = jnp.diagonal(target_cov)
            assert_allclose(cov, diag_cov, rtol=0.06)
            assert_allclose(cov_inv_sqrt, jnp.sqrt(jnp.reciprocal(diag_cov)), rtol=0.06)
        else:
            assert_allclose(cov, target_cov, rtol=0.06)
            assert_allclose(cov_inv_sqrt, jnp.linalg.cholesky(jnp.linalg.inv(cov)), rtol=0.06)
def final_fn(state, regularize=False):
    """
    :param state: Current state of the scheme.
    :param bool regularize: Whether to adjust diagonal for numerical stability.
    :return: a triple of estimated covariance, the square root of precision,
        and the inverse of that square root.
    """
    mean, m2, n = state
    # XXX it is not necessary to check for the case n=1
    cov = m2 / (n - 1)
    if regularize:
        # Regularization from Stan
        scaled_cov = (n / (n + 5)) * cov
        shrinkage = 1e-3 * (5 / (n + 5))
        if diagonal:
            cov = scaled_cov + shrinkage
        else:
            cov = scaled_cov + shrinkage * jnp.identity(mean.shape[0])
    if jnp.ndim(cov) == 2:
        # copy the implementation of distributions.util.cholesky_of_inverse here
        tril_inv = jnp.swapaxes(
            jnp.linalg.cholesky(cov[..., ::-1, ::-1])[..., ::-1, ::-1], -2, -1)
        identity = jnp.identity(cov.shape[-1])
        cov_inv_sqrt = solve_triangular(tril_inv, identity, lower=True)
    else:
        tril_inv = jnp.sqrt(cov)
        cov_inv_sqrt = jnp.reciprocal(tril_inv)
    return cov, cov_inv_sqrt, tril_inv
def pow_right(y, z, ildj_):
    # x ** y = z
    # x = f^-1(z) = z ** (1 / y)
    # grad(f^-1)(z) = 1 / y * z ** (1 / y - 1)
    # log(grad(f^-1)(z)) = (1 / y - 1) * log(z) - log(y)
    y_inv = np.reciprocal(y)
    return lax.pow(z, y_inv), ildj_ + (y_inv - 1.) * np.log(z) - np.log(y)
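# A quick numeric sanity check (a sketch, not part of the library; assumes the
# np/lax imports of the surrounding module): for x ** 2 = 9 the inverse is
# x = 3 and grad(f^-1)(9) = 1/2 * 9**(-1/2) = 1/6, so the ildj increment
# should equal log(1/6).
x, ildj = pow_right(2., 9., 0.)
print(x)                      # 3.0
print(ildj, np.log(1. / 6.))  # both approximately -1.7918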
def _sample_momentum(unpack_fn, inverse_mass_matrix, rng):
    if inverse_mass_matrix.ndim == 1:
        r = dist.norm(0., np.sqrt(np.reciprocal(inverse_mass_matrix))).rvs(random_state=rng)
        return unpack_fn(r)
    elif inverse_mass_matrix.ndim == 2:
        raise NotImplementedError
def restricted_mp2(geom, basis_name, xyz_path, nuclear_charges, charge, deriv_order=0):
    nelectrons = int(jnp.sum(nuclear_charges)) - charge
    ndocc = nelectrons // 2
    E_scf, C, eps, G = restricted_hartree_fock(geom, basis_name, xyz_path, nuclear_charges,
                                               charge, deriv_order=deriv_order,
                                               return_aux_data=True)

    nvirt = G.shape[0] - ndocc
    nbf = G.shape[0]

    G = partial_tei_transformation(G, C[:, :ndocc], C[:, ndocc:], C[:, :ndocc], C[:, ndocc:])

    # Create a tensor of dim (occ, vir, occ, vir) of all possible orbital energy denominators.
    # The partial tei transformation is super efficient; it is this part that is bad.
    eps_occ, eps_vir = eps[:ndocc], eps[ndocc:]
    e_denom = jnp.reciprocal(eps_occ.reshape(-1, 1, 1, 1) - eps_vir.reshape(-1, 1, 1)
                             + eps_occ.reshape(-1, 1) - eps_vir)

    # Tensor contraction algo
    # mp2_correlation = jnp.einsum('iajb,iajb,iajb->', G, G, e_denom) + \
    #                   jnp.einsum('iajb,iajb,iajb->', G - jnp.transpose(G, (0, 3, 2, 1)), G, e_denom)
    # mp2_total_energy = mp2_correlation + E_scf
    # return E_scf + mp2_correlation

    # Loop algo (lower memory, but the tei transform is the memory bottleneck).
    # Create all combinations of the four loop variables to make XLA compilation easier.
    indices = cartesian_product(jnp.arange(ndocc), jnp.arange(ndocc),
                                jnp.arange(nvirt), jnp.arange(nvirt))
    with loops.Scope() as s:
        s.mp2_correlation = 0.
        for idx in s.range(indices.shape[0]):
            i, j, a, b = indices[idx]
            s.mp2_correlation += G[i, a, j, b] * (2 * G[i, a, j, b] - G[i, b, j, a]) * e_denom[i, a, j, b]
        return E_scf + s.mp2_correlation
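# Equivalence sketch between the loop accumulation above and the commented-out
# einsum contraction, on random stand-in tensors (hypothetical shapes; plain
# numpy, independent of the quantum-chemistry helpers).
import numpy as onp

ndocc, nvirt = 3, 4
G_ = onp.random.rand(ndocc, nvirt, ndocc, nvirt)
d_ = onp.random.rand(ndocc, nvirt, ndocc, nvirt)

loop = 0.
for i in range(ndocc):
    for j in range(ndocc):
        for a in range(nvirt):
            for b in range(nvirt):
                loop += G_[i, a, j, b] * (2 * G_[i, a, j, b] - G_[i, b, j, a]) * d_[i, a, j, b]

contracted = onp.einsum('iajb,iajb,iajb->', G_, 2 * G_ - G_.transpose(0, 3, 2, 1), d_)
print(onp.allclose(loop, contracted))  # True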
def _sample_momentum(unpack_fn, inverse_mass_matrix, rng):
    if inverse_mass_matrix.ndim == 1:
        r = dist.Normal(0., np.sqrt(np.reciprocal(inverse_mass_matrix))).sample(rng)
        return unpack_fn(r)
    elif inverse_mass_matrix.ndim == 2:
        raise NotImplementedError
def body(state):
    p_k = -(state.H_k @ state.g_k)
    line_search_results = line_search(value_and_grad, state.x_k, p_k,
                                      old_fval=state.f_k, gfk=state.g_k,
                                      maxiter=ls_maxiter)
    state = state._replace(nfev=state.nfev + line_search_results.nfev,
                           ngev=state.ngev + line_search_results.ngev,
                           failed=line_search_results.failed,
                           ls_status=line_search_results.status)
    s_k = line_search_results.a_k * p_k
    x_kp1 = state.x_k + s_k
    f_kp1 = line_search_results.f_k
    g_kp1 = line_search_results.g_k
    y_k = g_kp1 - state.g_k
    rho_k = jnp.reciprocal(y_k @ s_k)

    sy_k = s_k[:, None] * y_k[None, :]
    w = jnp.eye(d) - rho_k * sy_k
    H_kp1 = jnp.where(jnp.isfinite(rho_k),
                      jnp.linalg.multi_dot([w, state.H_k, w.T])
                      + rho_k * s_k[:, None] * s_k[None, :],
                      state.H_k)

    converged = jnp.linalg.norm(g_kp1, ord=norm) < gtol

    state = state._replace(converged=converged,
                           k=state.k + 1,
                           x_k=x_kp1,
                           f_k=f_kp1,
                           g_k=g_kp1,
                           H_k=H_kp1)
    return state
def quantized_layernorm(x):
    prec = hparams.quant_hparams.prec
    fp_quant = QuantOps.FloatQuant(is_scaled=False, fp_spec=prec)
    quant_ops = QuantOps.create_symmetric_fp(fp_quant=fp_quant, bounds=None)

    def to_quantized(x):
        return quant_ops.to_quantized(x, dtype=dtype)

    # If epsilon is too small to represent in the quantized format, we set it
    # to the minimal representative non-zero value to avoid the possibility of
    # dividing by zero.
    fp_bounds = quantization.fp_cast.get_bounds(prec.exp_min, prec.exp_max, prec.sig_bits)
    epsilon = max(self.epsilon, fp_bounds.flush_to_zero_bound)
    quantized_epsilon = to_quantized(jnp.array(epsilon, dtype=dtype))

    # If the reciprocal of the quantized number of features is too small to
    # represent in the quantized format, we set it to the minimal
    # representative nonzero value so that the mean and variance are not
    # trivially 0.
    num_features_quantized = to_quantized(jnp.array(num_features, dtype=dtype))
    num_features_recip_quantized = to_quantized(jnp.reciprocal(num_features_quantized))
    num_features_recip_quantized = jax.lax.cond(
        jax.lax.eq(num_features_recip_quantized, 0.0),
        lambda _: quantized_epsilon,
        lambda _: num_features_recip_quantized,
        None)

    x_quantized = to_quantized(x)
    x_sum_quantized_reduction = quantization.quantized_sum(
        x_quantized, axis=-1, keepdims=True, prec=hparams.quant_hparams.reduction_prec)
    x_sum = to_quantized(x_sum_quantized_reduction)
    mean = to_quantized(x_sum * num_features_recip_quantized)
    x_minus_mean = to_quantized(x - mean)
    x_sq = to_quantized(lax.square(x_minus_mean))
    x_sq_sum_quantized_reduction = quantization.quantized_sum(
        x_sq, axis=-1, keepdims=True, prec=hparams.quant_hparams.reduction_prec)
    x_sq_sum = to_quantized(x_sq_sum_quantized_reduction)
    var = to_quantized(x_sq_sum * num_features_recip_quantized)
    # Prevent division by zero.
    var_plus_epsilon = to_quantized(var + quantized_epsilon)
    mul = to_quantized(lax.rsqrt(var_plus_epsilon))
    if self.use_scale:
        quantized_scale_param = to_quantized(scale_param)
        mul = to_quantized(mul * quantized_scale_param)
    y = to_quantized(x_minus_mean * mul)
    if self.use_bias:
        quantized_bias_param = to_quantized(bias_param)
        y = to_quantized(y + quantized_bias_param)
    return y.astype(self.dtype)
def body_fun(state: LBFGSResults):
    # find search direction
    p_k = _two_loop_recursion(state)

    # line search
    ls_results = line_search(
        f=fun,
        xk=state.x_k,
        pk=p_k,
        old_fval=state.f_k,
        gfk=state.g_k,
        maxiter=maxls,
    )

    # evaluate at next iterate
    s_k = ls_results.a_k * p_k
    x_kp1 = state.x_k + s_k
    f_kp1 = ls_results.f_k
    g_kp1 = ls_results.g_k
    y_k = g_kp1 - state.g_k
    rho_k_inv = jnp.real(_dot(y_k, s_k))
    rho_k = jnp.reciprocal(rho_k_inv)
    gamma = rho_k_inv / jnp.real(_dot(jnp.conj(y_k), y_k))

    # replacements for next iteration
    status = 0
    status = jnp.where(state.f_k - f_kp1 < ftol, 4, status)
    status = jnp.where(state.ngev >= maxgrad, 3, status)  # type: ignore
    status = jnp.where(state.nfev >= maxfun, 2, status)  # type: ignore
    status = jnp.where(state.k >= maxiter, 1, status)  # type: ignore
    status = jnp.where(ls_results.failed, 5, status)

    converged = jnp.linalg.norm(g_kp1, ord=norm) < gtol

    # TODO(jakevdp): use a fixed-point procedure rather than type-casting?
    state = state._replace(
        converged=converged,
        failed=(status > 0) & (~converged),
        k=state.k + 1,
        nfev=state.nfev + ls_results.nfev,
        ngev=state.ngev + ls_results.ngev,
        x_k=x_kp1.astype(state.x_k.dtype),
        f_k=f_kp1.astype(state.f_k.dtype),
        g_k=g_kp1.astype(state.g_k.dtype),
        s_history=_update_history_vectors(history=state.s_history, new=s_k),
        y_history=_update_history_vectors(history=state.y_history, new=y_k),
        rho_history=_update_history_scalars(history=state.rho_history, new=rho_k),
        gamma=gamma,
        status=jnp.where(converged, 0, status),
        ls_status=ls_results.status,
    )

    return state
def gaussian_euclidean_metric(
    inverse_mass_matrix: np.DeviceArray,
) -> Tuple[Callable, Callable]:
    """Emulate dynamics on an Euclidean Manifold [1]_ for vanilla Hamiltonian
    Monte Carlo with a standard gaussian as the conditional probability density
    :math:`\\pi(momentum|position)`.

    References
    ----------
    .. [1]: Betancourt, Michael. "A general metric for Riemannian manifold
            Hamiltonian Monte Carlo." International Conference on Geometric
            Science of Information. Springer, Berlin, Heidelberg, 2013.
    """
    ndim = np.ndim(inverse_mass_matrix)
    shape = np.shape(inverse_mass_matrix)[:1]

    if ndim == 1:  # diagonal mass matrix
        mass_matrix_sqrt = np.sqrt(np.reciprocal(inverse_mass_matrix))

        @jax.jit
        def momentum_generator(rng_key: jax.random.PRNGKey) -> np.DeviceArray:
            std = jax.random.normal(rng_key, shape)
            p = np.multiply(std, mass_matrix_sqrt)
            return p

        @jax.jit
        def kinetic_energy(momentum: np.DeviceArray) -> float:
            velocity = np.multiply(inverse_mass_matrix, momentum)
            return 0.5 * np.dot(velocity, momentum)

        return momentum_generator, kinetic_energy

    elif ndim == 2:
        mass_matrix_sqrt = cholesky_triangular(inverse_mass_matrix)

        @jax.jit
        def momentum_generator(rng_key: jax.random.PRNGKey) -> np.DeviceArray:
            std = jax.random.normal(rng_key, shape)
            p = np.dot(std, mass_matrix_sqrt)
            return p

        @jax.jit
        def kinetic_energy(momentum: np.DeviceArray) -> float:
            velocity = np.matmul(inverse_mass_matrix, momentum)
            return 0.5 * np.dot(velocity, momentum)

        return momentum_generator, kinetic_energy

    else:
        raise ValueError(
            "The mass matrix has the wrong number of dimensions: "
            "expected 1 or 2, got {}.".format(np.ndim(inverse_mass_matrix))
        )
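# Usage sketch (hypothetical values; assumes the surrounding module's
# convention of `import jax.numpy as np`): build the two callables for a
# diagonal inverse mass matrix and draw one momentum sample.
import jax
import jax.numpy as np

inverse_mass_matrix = np.array([0.5, 2.0, 1.0])
momentum_generator, kinetic_energy = gaussian_euclidean_metric(inverse_mass_matrix)
p = momentum_generator(jax.random.PRNGKey(0))
print(kinetic_energy(p))  # 0.5 * p^T M^{-1} p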
def integer_pow_inverse(z, *, y):
    """Inverse for `integer_pow_p` primitive."""
    if y == 0:
        raise ValueError('Cannot invert raising a value to the 0-th power.')
    elif y == 1:
        return z
    elif y == -1:
        return np.reciprocal(z)
    elif y == 2:
        return np.sqrt(z)
    return lax.pow(z, 1. / y)
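# Sanity sketch (not part of the module): each branch should undo
# lax.integer_pow for its exponent.
from jax import lax

for y in (1, -1, 2, 3):
    z = lax.integer_pow(2.0, y)
    print(y, integer_pow_inverse(z, y=y))  # 2.0 in every case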
def precision_matrix(self):
    # We use "Woodbury matrix identity" to take advantage of low rank form::
    #     inv(W @ W.T + D) = inv(D) - inv(D) @ W @ inv(C) @ W.T @ inv(D)
    # where :math:`C` is the capacitance matrix.
    Wt_Dinv = (np.swapaxes(self.cov_factor, -1, -2)
               / np.expand_dims(self.cov_diag, axis=-2))
    A = solve_triangular(Wt_Dinv, self._capacitance_tril, lower=True)
    # TODO: find a better solution to create a diagonal matrix
    inverse_cov_diag = np.reciprocal(self.cov_diag)
    diag_embed = inverse_cov_diag[..., np.newaxis] * np.identity(self.loc.shape[-1])
    return diag_embed - np.matmul(np.swapaxes(A, -1, -2), A)
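# Numeric sketch of the Woodbury identity the method relies on (plain numpy,
# hypothetical shapes): inv(W @ W.T + D) == inv(D) - inv(D) @ W @ inv(C) @ W.T @ inv(D)
# with capacitance C = I + W.T @ inv(D) @ W.
import numpy as onp

rng = onp.random.default_rng(0)
W = rng.standard_normal((5, 2))     # low-rank factor: dimension 5, rank 2
d = rng.uniform(0.5, 1.5, size=5)   # diagonal part
direct = onp.linalg.inv(W @ W.T + onp.diag(d))
Dinv = onp.diag(1. / d)
C = onp.eye(2) + W.T @ Dinv @ W
woodbury = Dinv - Dinv @ W @ onp.linalg.inv(C) @ W.T @ Dinv
print(onp.allclose(direct, woodbury))  # True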
def scaleSqFromModelParsSingleMu(A, e, M, W, etas, binCenters1, good_idx):
    coeffe1 = binCenters1[..., 1]
    coeffM1 = binCenters1[..., 0]

    # term1 = A[good_idx[0]] - e[good_idx[0]]*coeffe1 + M[good_idx[0]]*coeffM1 + W[good_idx[0]]*np.abs(coeffM1)
    term1 = (A[good_idx[0]] - 1.) + np.reciprocal(1 + e[good_idx[0]] * coeffe1) \
        + M[good_idx[0]] * coeffM1 + W[good_idx[0]] * np.abs(coeffM1)
    scaleSq = np.square(1. - term1)
    return scaleSq
def kl_divergence(p, q):
    # From https://arxiv.org/abs/1605.06197 Formula (12)
    a, b = p.concentration1, p.concentration0
    alpha, beta = q.concentration1, q.concentration0
    b_reciprocal = jnp.reciprocal(b)
    a_b = a * b
    t1 = (alpha / a - 1) * (jnp.euler_gamma + digamma(b) + b_reciprocal)
    t2 = jnp.log(a_b) + betaln(alpha, beta) + (b_reciprocal - 1)
    a_ = jnp.expand_dims(a, -1)
    b_ = jnp.expand_dims(b, -1)
    a_b_ = jnp.expand_dims(a_b, -1)
    m = jnp.arange(1, p.KL_KUMARASWAMY_BETA_TAYLOR_ORDER + 1)
    t3 = (beta - 1) * b * (jnp.exp(betaln(m / a_, b_)) / (m + a_b_)).sum(-1)
    return t1 + t2 + t3
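# Monte Carlo cross-check sketch (plain numpy/scipy, hypothetical parameters):
# estimate KL(Kumaraswamy(a, b) || Beta(alpha, beta)) by sampling via the
# Kumaraswamy inverse CDF and compare against the closed-form value above.
import numpy as onp
from scipy import stats

a, b, alpha, beta = 2.0, 3.0, 1.5, 2.5
u = onp.random.default_rng(0).uniform(size=200_000)
x = (1. - (1. - u) ** (1. / b)) ** (1. / a)  # Kumaraswamy(a, b) samples
log_p = onp.log(a) + onp.log(b) + (a - 1) * onp.log(x) + (b - 1) * onp.log1p(-x ** a)
log_q = stats.beta(alpha, beta).logpdf(x)
print(onp.mean(log_p - log_q))  # should approach the value of kl_divergence(p, q)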
def MSAWeight_PB(msa):
    gap_idx = msa.abc.charmap['-']
    q = msa.abc.q
    ax = msa.ax
    (N, L) = ax.shape

    ## step 1: get counts
    c = np.sum(msa.ax_1hot, axis=0)
    # set gap counts to 0
    c = index_update(c, index[:, gap_idx], 0)

    # get N x L array with count value for corresponding residue in alignment
    # first, get N x L "column id" array (convenient for vmap)
    # col_id[n, i] = i
    col_id = np.int16(np.tensordot(np.ones(N), np.arange(L), axes=0))
    # ax_c[n, i] = c[i, ax[n, i]]
    ax_c = Get_Henikoff_Counts_Residue(col_id, ax, c)

    ## step 2: get number of unique characters in each column
    r = np.float32(np.sum(np.array(c > 0), axis=1))
    # transform r from an L-vector to an N x L array, where r2[n, i] = r[i];
    # this allows easy elementwise operations with ax_c
    r2 = np.tensordot(np.ones(N), r, axes=0)

    ## step 3: get ungapped seq lengths
    nongap = np.array(ax != gap_idx)
    l = np.float32(np.sum(nongap, axis=1))

    ## step 4: calculate unnormalized weights
    # get array of main terms in the Henikoff sum
    # wgt_un[n, i] = 1 / (r[i] * c[i, ax[n, i]])
    wgt_un = np.reciprocal(np.multiply(ax_c, r2))
    # set all terms involving a gap to zero
    wgt_un = np.nan_to_num(np.multiply(wgt_un, nongap))
    # sum across all positions to get the preliminary unnormalized weight for each sequence
    wgt_un = np.sum(wgt_un, axis=1)
    # divide by gapless sequence length
    wgt_un = np.divide(wgt_un, l)

    ## step 5: normalize sequence weights
    wgt = (wgt_un * np.float32(N)) / np.sum(wgt_un)
    msa.wgt = wgt
    return
def sigmaSqFromModelParsSingleMu(a, b, c, d, etas, binCenters1, good_idx):
    # compute sigma from physics parameters
    pt2 = binCenters1[..., 2]
    L2 = binCenters1[..., 3]
    # corr = binCenters1[..., 4]
    invpt2 = binCenters1[..., 4]
    sigmaSq = a[good_idx[0]] * L2 + c[good_idx[0]] * pt2 * np.square(L2) \
        + b[good_idx[0]] * L2 * np.reciprocal(1 + d[good_idx[0]] * invpt2 / L2)
    # sigmaSq = a[good_idx[0]]*L2 + c[good_idx[0]]*pt2*np.square(L2) + corr
    return sigmaSq
def var_exp(self, freqs, Y_obs, sigma, amp, mu, gamma):
    """
    Computes the variational expectation.

    Args:
        freqs: [Nf]
        Y_obs: [2Nf] (real parts then imaginary parts)
        sigma: [2Nf]
        amp: [Nf]
        mu: [M]
        gamma: [M]

    Returns: scalar
    """
    f = self._phase_basis(self.freqs)  # Nf, M
    Nf = freqs.size
    Sigma_real = jnp.square(sigma[:Nf])
    Sigma_imag = jnp.square(sigma[Nf:])
    Yreal = Y_obs[:Nf]
    Yimag = Y_obs[Nf:]
    a = jnp.reciprocal(Sigma_real)
    b = jnp.reciprocal(Sigma_imag)
    constant = -Nf * jnp.log(2. * jnp.pi)
    logdet = -jnp.sum(jnp.log(sigma))
    phi = jnp.dot(f, mu)
    theta = jnp.dot(jnp.square(f), jnp.square(gamma))
    exp_cos = jnp.exp(-0.5 * theta) * jnp.cos(phi)
    exp_sin = jnp.exp(-0.5 * theta) * jnp.sin(phi)
    exp_cos2 = 0.5 * (jnp.exp(-2. * theta) * jnp.cos(2. * phi) + 1.)
    negtwo_maha = a * (jnp.square(Yreal) - 2. * amp * Yreal * exp_cos) \
        + b * (jnp.square(Yimag) - 2. * amp * Yimag * exp_sin) \
        + ((a - b) * exp_cos2 + b) * jnp.square(amp)
    return constant + logdet - 0.5 * jnp.sum(negtwo_maha)
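# Monte Carlo sketch (plain numpy, hypothetical values) of the Gaussian moment
# identities used above: for X ~ N(phi, theta) one has
# E[cos X] = exp(-theta/2) cos(phi) and E[cos^2 X] = 0.5 * (exp(-2 theta) cos(2 phi) + 1).
import numpy as onp

phi, theta = 0.7, 0.3
x = onp.random.default_rng(0).normal(phi, onp.sqrt(theta), size=500_000)
print(onp.mean(onp.cos(x)), onp.exp(-0.5 * theta) * onp.cos(phi))
print(onp.mean(onp.cos(x) ** 2), 0.5 * (onp.exp(-2. * theta) * onp.cos(2. * phi) + 1.))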
def init_fn(z_info, rng_key, step_size=1.0, inverse_mass_matrix=None, mass_matrix_size=None):
    """
    :param IntegratorState z_info: The initial integrator state.
    :param jax.random.PRNGKey rng_key: Random key to be used as the source of randomness.
    :param float step_size: Initial step size.
    :param inverse_mass_matrix: Inverse of the initial mass matrix. If ``None``,
        the inverse mass matrix will be an identity matrix whose size is decided
        by the argument `mass_matrix_size`.
    :param int mass_matrix_size: Size of the mass matrix.
    :return: initial state of the adapt scheme.
    """
    rng_key, rng_key_ss = random.split(rng_key)
    if inverse_mass_matrix is None:
        assert mass_matrix_size is not None
        if dense_mass:
            inverse_mass_matrix = jnp.identity(mass_matrix_size)
        else:
            inverse_mass_matrix = jnp.ones(mass_matrix_size)
        mass_matrix_sqrt = mass_matrix_sqrt_inv = inverse_mass_matrix
    else:
        if dense_mass:
            mass_matrix_sqrt_inv = jnp.swapaxes(
                jnp.linalg.cholesky(inverse_mass_matrix[..., ::-1, ::-1])[..., ::-1, ::-1],
                -2, -1)
            identity = jnp.identity(inverse_mass_matrix.shape[-1])
            mass_matrix_sqrt = solve_triangular(mass_matrix_sqrt_inv, identity, lower=True)
        else:
            mass_matrix_sqrt_inv = jnp.sqrt(inverse_mass_matrix)
            mass_matrix_sqrt = jnp.reciprocal(mass_matrix_sqrt_inv)
    if adapt_step_size:
        step_size = find_reasonable_step_size(step_size, inverse_mass_matrix, z_info, rng_key_ss)
    ss_state = ss_init(jnp.log(10 * step_size))
    mm_state = mm_init(inverse_mass_matrix.shape[-1])
    window_idx = 0
    return HMCAdaptState(step_size, inverse_mass_matrix, mass_matrix_sqrt,
                         mass_matrix_sqrt_inv, ss_state, mm_state, window_idx, rng_key)
def unquantized_layernorm(x):
    num_features_recip = jnp.reciprocal(num_features)
    x_sum = jnp.sum(x, axis=-1, keepdims=True)
    mean = x_sum * num_features_recip
    x_minus_mean = x - mean
    x_sq = lax.square(x_minus_mean)
    x_sq_sum = jnp.sum(x_sq, axis=-1, keepdims=True)
    var = x_sq_sum * num_features_recip
    var_plus_epsilon = var + self.epsilon
    mul = lax.rsqrt(var_plus_epsilon)
    if self.use_scale:
        mul = mul * scale_param
    y = x_minus_mean * mul
    if self.use_bias:
        y = y + bias_param
    return y.astype(self.dtype)
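# Equivalence sketch (plain numpy, no scale/bias, hypothetical shapes): the
# reciprocal-based running products above reproduce the direct layernorm
# normalization (x - mean) / sqrt(var + eps).
import numpy as onp

x = onp.random.default_rng(0).standard_normal((4, 16))
eps = 1e-6
num_features_recip = 1. / x.shape[-1]
mean = x.sum(axis=-1, keepdims=True) * num_features_recip
var = ((x - mean) ** 2).sum(axis=-1, keepdims=True) * num_features_recip
y = (x - mean) / onp.sqrt(var + eps)
print(onp.allclose(
    y, (x - x.mean(-1, keepdims=True)) / onp.sqrt(x.var(-1, keepdims=True) + eps)))  # True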
def body(state):
    (i, _, key, done, _) = state
    key, accept_key, sample_key, select_key = random.split(key, 4)
    k = random.categorical(select_key, log_p)
    mu_k = mu[k, :]
    radii_k = radii[k, :]
    rotation_k = rotation[k, :, :]
    u_test = sample_ellipsoid(sample_key, mu_k, radii_k, rotation_k,
                              unit_cube_constraint=unit_cube_constraint)
    inside = vmap(lambda mu, radii, rotation:
                  point_in_ellipsoid(u_test, mu, radii, rotation))(mu, radii, rotation)
    n_intersect = jnp.sum(inside)
    done = (random.uniform(accept_key) < jnp.reciprocal(n_intersect))
    return (i + 1, k, key, done, u_test)
def ellipsoid_params(C):
    """
    If C satisfies the sectional inequality,

        (x - mu)^T C (x - mu) <= 1

    then this returns the radii and rotation matrix of the ellipsoid.

    Args:
        C: [D, D]

    Returns:
        radii [D]
        rotation [D, D]
    """
    W, Q, Vh = jnp.linalg.svd(C)
    radii = jnp.reciprocal(jnp.sqrt(Q))
    radii = jnp.where(jnp.isnan(radii), 0., radii)
    rotation = Vh.conj().T
    return radii, rotation
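# Sanity sketch (hypothetical values): for a diagonal C = diag(1/r1^2, 1/r2^2)
# the ellipsoid (x - mu)^T C (x - mu) <= 1 has radii (r1, r2), which the SVD
# recovers.
import jax.numpy as jnp

C = jnp.diag(jnp.array([1. / 4., 1. / 9.]))  # radii 2 and 3
radii, rotation = ellipsoid_params(C)
print(radii)  # approximately [2., 3.]; the largest eigenvalue of C gives the smallest radius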
def debug_mvee():
    import pylab as plt

    n = random.normal(random.PRNGKey(0), (10000, 2))
    n = n / jnp.linalg.norm(n, axis=1, keepdims=True)
    angle = jnp.arctan2(n[:, 1], n[:, 0])
    plt.hist(angle, bins=100)
    plt.show()

    N = 120
    D = 2
    points = random.uniform(random.PRNGKey(0), (N, D))

    from jax import disable_jit
    with disable_jit():
        center, radii, rotation = minimum_volume_enclosing_ellipsoid(points, 0.01)

    plt.hist(jnp.linalg.norm((rotation.T @ (points.T - center[:, None])) / radii[:, None], axis=0))
    plt.show()

    print(center, radii, rotation)
    plt.scatter(points[:, 0], points[:, 1])
    theta = jnp.linspace(0., jnp.pi * 2, 100)
    ellipsis = center[:, None] + rotation @ jnp.stack([radii[0] * jnp.cos(theta),
                                                       radii[1] * jnp.sin(theta)], axis=0)
    plt.plot(ellipsis[0, :], ellipsis[1, :])

    for i in range(1000):
        y = sample_ellipsoid(random.PRNGKey(i), center, radii, rotation)
        plt.scatter(y[0], y[1])

    C = jnp.linalg.pinv(jnp.cov(points, rowvar=False, bias=True))
    p = (N - D - 1) / N

    def q(p):
        return p + p ** 2 / (4. * (D - 1))

    C = C / q(p)
    c = jnp.mean(points, axis=0)
    W, Q, Vh = jnp.linalg.svd(C)
    radii = jnp.reciprocal(jnp.sqrt(Q))
    rotation = Vh.conj().T
    ellipsis = c[:, None] + rotation @ jnp.stack([radii[0] * jnp.cos(theta),
                                                  radii[1] * jnp.sin(theta)], axis=0)
    plt.plot(ellipsis[0, :], ellipsis[1, :])
    plt.show()
def body_fun(state):
    p_k = -_dot(state.H_k, state.g_k)
    line_search_results = line_search(
        fun,
        state.x_k,
        p_k,
        old_fval=state.f_k,
        old_old_fval=state.old_old_fval,
        gfk=state.g_k,
        maxiter=line_search_maxiter,
    )
    state = state._replace(
        nfev=state.nfev + line_search_results.nfev,
        ngev=state.ngev + line_search_results.ngev,
        failed=line_search_results.failed,
        line_search_status=line_search_results.status,
    )
    s_k = line_search_results.a_k * p_k
    x_kp1 = state.x_k + s_k
    f_kp1 = line_search_results.f_k
    g_kp1 = line_search_results.g_k
    y_k = g_kp1 - state.g_k
    rho_k = jnp.reciprocal(_dot(y_k, s_k))

    sy_k = s_k[:, jnp.newaxis] * y_k[jnp.newaxis, :]
    w = jnp.eye(d, dtype=rho_k.dtype) - rho_k * sy_k
    H_kp1 = (_einsum('ij,jk,lk', w, state.H_k, w)
             + rho_k * s_k[:, jnp.newaxis] * s_k[jnp.newaxis, :])
    H_kp1 = jnp.where(jnp.isfinite(rho_k), H_kp1, state.H_k)

    converged = jnp.linalg.norm(g_kp1, ord=norm) < gtol

    state = state._replace(
        converged=converged,
        k=state.k + 1,
        x_k=x_kp1,
        f_k=f_kp1,
        g_k=g_kp1,
        H_k=H_kp1,
        old_old_fval=state.f_k,
    )
    return state
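# Equivalence sketch for the inverse-Hessian update written two ways (plain
# numpy, random stand-in data): the einsum form above matches the textbook
# BFGS update H+ = (I - rho s y^T) H (I - rho y s^T) + rho s s^T.
import numpy as onp

rng = onp.random.default_rng(0)
dim = 4
H = rng.standard_normal((dim, dim)); H = H @ H.T
s, y = rng.standard_normal(dim), rng.standard_normal(dim)
rho = 1. / (y @ s)
w = onp.eye(dim) - rho * onp.outer(s, y)
via_einsum = onp.einsum('ij,jk,lk', w, H, w) + rho * onp.outer(s, s)
textbook = w @ H @ w.T + rho * onp.outer(s, s)
print(onp.allclose(via_einsum, textbook))  # True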
def final_fn(state, regularize=False):
    """
    :param state: Current state of the scheme.
    :param bool regularize: Whether to adjust diagonal for numerical stability.
    :return: a pair of estimated covariance and the square root of precision.
    """
    mean, m2, n = state
    # XXX it is not necessary to check for the case n=1
    cov = m2 / (n - 1)
    if regularize:
        # Regularization from Stan
        scaled_cov = (n / (n + 5)) * cov
        shrinkage = 1e-3 * (5 / (n + 5))
        if diagonal:
            cov = scaled_cov + shrinkage
        else:
            cov = scaled_cov + shrinkage * np.identity(mean.shape[0])
    if np.ndim(cov) == 2:
        cov_inv_sqrt = cholesky_inverse(cov)
    else:
        cov_inv_sqrt = np.sqrt(np.reciprocal(cov))
    return cov, cov_inv_sqrt
def sample_ellipsoid(key, center, radii, rotation, unit_cube_constraint=False):
    """
    Sample uniformly inside an ellipsoid.

    When unit_cube_constraint=True, the random radius chosen during sampling
    is constrained so that the sample stays inside the unit cube. With
    u(t) = R @ (t * n) + c, solving u(t) == 1 per dimension gives
    1 - c = t * R @ n, i.e. t = (1 - c) / (R @ n); take the minimum t
    satisfying this, and likewise for the intersection with zero.

    Args:
        key:
        center: [D]
        radii: [D]
        rotation: [D,D]

    Returns:
        [D]
    """
    direction_key, radii_key = random.split(key, 2)
    direction = random.normal(direction_key, shape=radii.shape)
    if unit_cube_constraint:
        direction = direction / jnp.linalg.norm(direction)
        R = rotation * radii
        D = R @ direction
        t0 = -center / D
        t1 = jnp.reciprocal(D) + t0
        t0 = jnp.where(t0 < 0., jnp.inf, t0)
        t1 = jnp.where(t1 < 0., jnp.inf, t1)
        t = jnp.minimum(jnp.min(t0), jnp.min(t1))
        t = jnp.minimum(t, 1.)
        # radius fraction r = t * U**(1/D): the inverse CDF of the
        # uniform-in-ball radius law truncated to [0, t], so that r <= t
        r = t * jnp.exp(jnp.log(random.uniform(radii_key)) / radii.size)
        return r * D + center
    log_norm = jnp.log(jnp.linalg.norm(direction))
    log_radius = jnp.log(random.uniform(radii_key)) / radii.size
    # x = direction * (radius / norm)
    x = direction * jnp.exp(log_radius - log_norm)
    return circle_to_ellipsoid(x, center, radii, rotation)
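# Check sketch (hypothetical values; assumes the module's jax imports): with
# the truncated radius law above, constrained samples should stay inside both
# the ellipsoid and the unit cube.
import jax.numpy as jnp
from jax import random, vmap

center = jnp.array([0.5, 0.5])
radii = jnp.array([0.4, 0.2])
rotation = jnp.eye(2)
keys = random.split(random.PRNGKey(0), 100)
samples = vmap(lambda k: sample_ellipsoid(
    k, center, radii, rotation, unit_cube_constraint=True))(keys)
print(bool(jnp.all((samples >= 0.) & (samples <= 1.))))  # True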
def body(state):
    (key, i, u_test, x_test, log_L_test) = state
    key, uniform_key, beta_key = random.split(key, 3)
    # [M]
    U_scale = random.uniform(uniform_key, shape=spawn_point_U.shape,
                             minval=t_L, maxval=t_R)
    t_shrink = random.beta(beta_key, live_points_U.shape[0], 1) ** jnp.reciprocal(spawn_point_U.size)
    u_test_white = U_scale / t_shrink
    # y_j = dx + sum_i p_i * u_i
    #     = dx + R @ u
    # x_i = x0_i + R_ij u_j
    if whiten:
        u_test = L @ (spawn_point_U + R @ u_test_white) + u_mean
    else:
        u_test = u_test_white
    u_test = jnp.clip(u_test, 0., 1.)
    x_test = prior_transform(u_test)
    log_L_test = loglikelihood_from_constrained(**x_test)
    return (key, i + 1, u_test, x_test, log_L_test)