def test_avg_pool(self):
    X1 = np.ones((4, 2, 3, 2))
    X2 = np.ones((3, 2, 3, 2))

    _, apply_fn, kernel_fn = stax.AvgPool((2, 2), (1, 1), 'SAME',
                                          normalize_edges=False)
    _, apply_fn_norm, kernel_fn_norm = stax.AvgPool((2, 2), (1, 1), 'SAME',
                                                    normalize_edges=True)
    _, apply_fn_stax = stax.ostax.AvgPool((2, 2), (1, 1), 'SAME')

    out1 = apply_fn((), X1)
    out2 = apply_fn((), X2)
    out1_norm = apply_fn_norm((), X1)
    out2_norm = apply_fn_norm((), X2)
    out1_stax = apply_fn_stax((), X1)
    out2_stax = apply_fn_stax((), X2)

    self.assertAllClose((out1_stax, out2_stax), (out1_norm, out2_norm), True)

    out_unnorm = np.array([[1., 1., 0.5], [0.5, 0.5, 0.25]]).reshape(
        (1, 2, 3, 1))
    out1_unnormalized = np.broadcast_to(out_unnorm, X1.shape)
    out2_unnormalized = np.broadcast_to(out_unnorm, X2.shape)
    self.assertAllClose((out1_unnormalized, out2_unnormalized), (out1, out2),
                        True)

    ker = kernel_fn(X1, X2)
    ker_norm = kernel_fn_norm(X1, X2)

    self.assertAllClose(np.ones_like(ker_norm.nngp), ker_norm.nngp, True)
    self.assertAllClose(np.ones_like(ker_norm.var1), ker_norm.var1, True)
    self.assertAllClose(np.ones_like(ker_norm.var2), ker_norm.var2, True)

    self.assertEqual(ker_norm.nngp.shape, ker.nngp.shape)
    self.assertEqual(ker_norm.var1.shape, ker.var1.shape)
    self.assertEqual(ker_norm.var2.shape, ker.var2.shape)

    ker_unnorm = np.outer(out_unnorm, out_unnorm).reshape((2, 3, 2, 3))
    ker_unnorm = np.transpose(ker_unnorm, axes=(0, 2, 1, 3))
    nngp = np.broadcast_to(
        ker_unnorm.reshape((1, 1) + ker_unnorm.shape), ker.nngp.shape)
    var1 = np.broadcast_to(np.expand_dims(ker_unnorm, 0), ker.var1.shape)
    var2 = np.broadcast_to(np.expand_dims(ker_unnorm, 0), ker.var2.shape)
    self.assertAllClose((nngp, var1, var2), (ker.nngp, ker.var1, ker.var2),
                        True)
def update(self, x, y_true, params, averager=None):
    # Run forward pass.
    z1, h1, z2, h2 = self.forward(x, params, return_activations=True)

    # Compute error for final layer (= gradient of cost w.r.t. layer input).
    e2 = h2 - y_true  # gradient through cross entropy loss

    # Compute gradients of cost w.r.t. parameters.
    grad_b2 = e2
    grad_W2 = np.outer(h1, e2)

    # Update parameters.
    self.b2 -= params['lr'] * grad_b2
    self.W2 -= params['lr'] * grad_W2

    return h2
def update_fn(sample, state):
    """
    :param sample: A new sample.
    :param state: Current state of the scheme.
    :return: new state for the scheme.
    """
    mean, m2, n = state
    n = n + 1
    delta_pre = sample - mean
    mean = mean + delta_pre / n
    delta_post = sample - mean
    if diagonal:
        m2 = m2 + delta_pre * delta_post
    else:
        m2 = m2 + np.outer(delta_post, delta_pre)
    return mean, m2, n
def qmult(key, b):
    """
    QMULT  Pre-multiply by random orthogonal matrix.

    QMULT(A) is Q*A where Q is a random real orthogonal matrix from the
    Haar distribution, of dimension the number of rows in A. Special case:
    if A is a scalar then QMULT(A) is the same as QMULT(EYE(A)).

    Called by RANDSVD.

    Reference:
    G.W. Stewart, The efficient generation of random orthogonal matrices
    with an application to condition estimators, SIAM J. Numer. Anal.,
    17 (1980), 403-409.
    """
    try:
        n = b.shape[0]
        a = b.copy()
    except AttributeError:
        n = b
        a = np.eye(n)

    d = np.zeros(n)
    for k in range(n - 2, -1, -1):
        # Generate random Householder transformation.
        key, subkey = random.split(key)
        x = random.normal(subkey, (n - k,))
        s = np.linalg.norm(x)
        # Modification to make sign(0) == 1.
        sgn = np.sign(x[0]) + float(x[0] == 0)
        s = sgn * s
        d = index_update(d, k, -sgn)
        x = index_update(x, 0, x[0] + s)
        beta = s * x[0]

        # Apply the transformation to a.
        y = np.dot(x, a[k:n, :])
        a = index_update(a, index[k:n, :], a[k:n, :] - np.outer(x, (y / beta)))

    # Tidy up signs.
    for i in range(n - 1):
        a = index_update(a, index[i, :], d[i] * a[i, :])

    # Now randomly change the sign (Gaussian dist).
    a = index_update(a, index[n - 1, :],
                     a[n - 1, :] * np.sign(random.normal(key, ())))
    return a
def testOrderOneDegreeOne(self):
    """Tests the spherical harmonics of order one and degree one."""
    num_theta = 7
    num_phi = 8
    theta = jnp.linspace(0, math.pi, num_theta)
    phi = jnp.linspace(0, 2.0 * math.pi, num_phi)
    expected = -1.0 / 2.0 * jnp.sqrt(3.0 / (2.0 * math.pi)) * jnp.outer(
        jnp.sin(theta), jnp.exp(1j * phi))
    sph_harm = spherical_harmonics.SphericalHarmonics(
        l_max=1, theta=theta, phi=phi)
    actual = sph_harm.harmonics_nonnegative_order()[1, 1, :, :]
    np.testing.assert_allclose(
        jnp.abs(actual), jnp.abs(expected), rtol=1e-8, atol=6e-8)
def var(kappa: float, mu: jnp.ndarray) -> jnp.ndarray:
    """Compute the variance of the power spherical distribution.

    Args:
        kappa: Concentration parameter.
        mu: Mean direction on the sphere. The dimensionality of the sphere
            is determined from this parameter.

    Returns:
        out: The variance of the power spherical distribution.
    """
    d = mu.size
    alpha = (d - 1.) / 2. + kappa
    beta = (d - 1.) / 2.
    return (2 * alpha / ((alpha + beta)**2 * (alpha + beta + 1.))
            * ((beta - alpha) * jnp.outer(mu, mu)
               + (alpha + beta) * jnp.eye(d)))
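# Usage sketch for `var` above (a quick sanity check, not part of the
# original source): the result is a d x d matrix, and its diagonal shrinks
# as the concentration kappa grows.
import jax.numpy as jnp

mu = jnp.array([1.0, 0.0, 0.0])
cov_loose = var(1.0, mu)    # broad distribution
cov_tight = var(100.0, mu)  # tightly concentrated around mu
assert cov_loose.shape == (3, 3)
assert jnp.all(jnp.diag(cov_tight) <= jnp.diag(cov_loose))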
def linreg_imputation_model(X, y):
    ndims = X.shape[1]
    a = numpyro.sample("a", dist.Normal(0, 0.5))
    beta = numpyro.sample("beta", dist.Normal(0, 0.5).expand([ndims]))
    sigma_y = numpyro.sample("sigma_y", dist.Exponential(1))

    # X_impute contains imputed data for each feature as a list;
    # X_merged is the observed data filled with imputed values at the
    # missing points.
    X_impute = [None] * ndims
    X_merged = [None] * ndims
    for i in range(ndims):  # for every feature
        no_of_missed = int(np.isnan(X[:, i]).sum())
        if no_of_missed != 0:
            # Each NaN value is associated with an imputed variable with a
            # standard normal prior.
            X_impute[i] = numpyro.sample(
                "X_impute_{}".format(i),
                dist.Normal(0, 1).expand([no_of_missed]).mask(False))
            # Merge the observed data with the imputed values.
            missed_idx = np.nonzero(np.isnan(X[:, i]))[0]
            X_merged[i] = ops.index_update(X[:, i], missed_idx, X_impute[i])
        else:
            # If there are no missing values, it's just the observed data.
            X_merged[i] = X[:, i]
    merged_X = jnp.stack(X_merged).T

    # LKJ is the distribution to model correlation matrices.
    rho = numpyro.sample("rho", dist.LKJ(ndims, 2))  # correlation matrix
    sigma_x = numpyro.sample("sigma_x", dist.Exponential(1).expand([ndims]))
    covariance_x = jnp.outer(sigma_x, sigma_x) * rho  # covariance matrix
    mu_x = numpyro.sample("mu_x", dist.Normal(0, 0.5).expand([ndims]))
    numpyro.sample("X_merged", dist.MultivariateNormal(mu_x, covariance_x),
                   obs=merged_X)

    mu_y = a + merged_X @ beta
    numpyro.sample("y", dist.Normal(mu_y, sigma_y), obs=y)
def weighted_sum(mean, cov, weights):
    """
    Computes the mean and variance of a weighted sum of an MVN random
    variable.

    Args:
        mean (np.array): The mean of the MVN.
        cov (np.array): The covariance of the MVN.
        weights (np.array): A vector of weights to give the elements.

    Returns:
        Tuple[float, float]: The mean and variance of the weighted sum.
    """
    mean_summed_theta = np.dot(mean, weights)

    outer_x = np.outer(weights, weights)
    multiplied = cov * outer_x
    weighted_sum = np.sum(multiplied)

    return mean_summed_theta, weighted_sum
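# Sanity check for `weighted_sum` (illustrative inputs): the double sum
# np.sum(cov * np.outer(w, w)) equals the quadratic form w @ cov @ w, and
# the mean is just the dot product of the weights with the MVN mean.
mean = np.array([1.0, -2.0])
cov = np.array([[2.0, 0.3], [0.3, 1.0]])
weights = np.array([0.5, 0.5])
m, v = weighted_sum(mean, cov, weights)
assert np.isclose(m, mean @ weights)
assert np.isclose(v, weights @ cov @ weights)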
def linreg_model(X, y):
    ndims = X.shape[1]
    a = numpyro.sample("a", dist.Normal(0, 0.5))
    beta = numpyro.sample("beta", dist.Normal(0, 0.5).expand([ndims]))
    sigma_y = numpyro.sample("sigma_y", dist.Exponential(1))

    # LKJ is the distribution to model correlation matrices.
    rho = numpyro.sample("rho", dist.LKJ(ndims, 2))  # correlation matrix
    sigma_x = numpyro.sample("sigma_x", dist.Exponential(1).expand([ndims]))
    covariance_x = jnp.outer(sigma_x, sigma_x) * rho  # covariance matrix
    mu_x = numpyro.sample("mu_x", dist.Normal(0, 0.5).expand([ndims]))
    numpyro.sample("X", dist.MultivariateNormal(mu_x, covariance_x), obs=X)

    mu_y = a + X @ beta
    numpyro.sample("y", dist.Normal(mu_y, sigma_y), obs=y)
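# Hedged usage sketch for `linreg_model`: fit with NUTS on synthetic data.
# The data shapes and sampler settings below are illustrative, not from
# the original source.
from jax import random
from numpyro.infer import MCMC, NUTS

key_x, key_noise = random.split(random.PRNGKey(0))
X_demo = random.normal(key_x, (100, 3))
y_demo = (X_demo @ jnp.array([0.2, -0.4, 0.1])
          + 0.1 * random.normal(key_noise, (100,)))

mcmc = MCMC(NUTS(linreg_model), num_warmup=500, num_samples=1000)
mcmc.run(random.PRNGKey(1), X=X_demo, y=y_demo)
mcmc.print_summary()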
def householder_product(inputs: Array, q_vector: Array) -> Array:
    """
    Args:
        inputs (Array): inputs for the householder product, (D,)
        q_vector (Array): vector to be multiplied, (D,)

    Returns:
        outputs (Array): outputs after the householder product
    """
    # Squared norm of q_vector.
    squared_norm = jnp.sum(q_vector**2)
    # Inner product.
    temp = jnp.dot(inputs, q_vector)
    # Outer product.
    temp = jnp.outer(temp, (2.0 / squared_norm) * q_vector).squeeze()
    # Update.
    output = inputs - temp
    return output
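# Quick check for `householder_product` (illustrative): it should agree
# with applying the explicit reflection matrix H = I - 2 q q^T / ||q||^2.
import jax
import jax.numpy as jnp

k1, k2 = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(k1, (4,))
q = jax.random.normal(k2, (4,))
H = jnp.eye(4) - 2.0 * jnp.outer(q, q) / jnp.sum(q**2)
assert jnp.allclose(householder_product(x, q), H @ x, atol=1e-5)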
def update_fun(step, grads, state):
    """Apply a step of the optimizer."""
    del step  # Unused.
    params, grad_seq = state

    # Update gradient history.
    grad_seq = append_to_sequence(grad_seq, grads)

    # Compute normalized gram matrix.
    gram = innerprod(grad_seq, grad_seq)
    grad_norm = norms(grad_seq)
    gram /= (jnp.outer(grad_norm, grad_norm) + 1e-6)

    # Compute update terms.
    attn_weights = jnp.dot(stax.softmax(gram, axis=0), theta_gram)
    attn_term = jnp.tensordot(attn_weights, grad_seq, axes=1)
    grad_term = jnp.tensordot(theta_grad, grad_seq, axes=1)

    params -= (grad_term + attn_term)
    return (params, grad_seq)
def nngp_ntk_fn(nngp, q11, q22, ntk=None):
    """Simple Gauss-Hermite quadrature routine."""
    xs, ws = quad_points
    grid = np.outer(ws, ws)
    x = xs.reshape((xs.shape[0],) + (1,) * (nngp.ndim + 1))
    y = xs.reshape((1, xs.shape[0]) + (1,) * nngp.ndim)
    xy_axes = (0, 1)

    nngp = np.expand_dims(nngp, xy_axes)
    q11, q22 = np.expand_dims(q11, xy_axes), np.expand_dims(q22, xy_axes)

    def integrate(f):
        fvals = f(_sqrt(2 * q11) * x) * f(
            nngp / _sqrt(q11 / 2, 1e-30) * x
            + _sqrt(2 * (q22 - nngp**2 / q11)) * y)
        return np.tensordot(grid, fvals, (xy_axes, xy_axes)) / np.pi

    if ntk is not None:
        ntk *= integrate(df)
    nngp = integrate(fn)
    return nngp, ntk
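# `nngp_ntk_fn` closes over `quad_points`, `fn`, and `df`, which are not
# shown above. A minimal sketch of plausible definitions (assumptions, not
# the original source): Gauss-Hermite nodes/weights for the weight function
# exp(-x^2), with `fn` the pointwise nonlinearity and `df` its derivative.
import numpy as onp

_xs, _ws = onp.polynomial.hermite.hermgauss(25)
quad_points = (np.asarray(_xs), np.asarray(_ws))
fn = np.tanh
df = lambda x: 1.0 / np.cosh(x)**2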
def linear_matter_power(cosmo, k, a=1.0, transfer_fn=tklib.Eisenstein_Hu,
                        **kwargs):
    r"""Computes the linear matter power spectrum.

    Parameters
    ----------
    k: array_like
        Wave number in h Mpc^{-1}.
    a: array_like, optional
        Scale factor (def: 1.0).
    transfer_fn: transfer_fn(cosmo, k, **kwargs)
        Transfer function.

    Returns
    -------
    pk: array_like
        Linear matter power spectrum at the specified scale and scale
        factor.
    """
    k = np.atleast_1d(k)
    a = np.atleast_1d(a)
    g = bkgrd.growth_factor(cosmo, a)
    t = transfer_fn(cosmo, k, **kwargs)

    pknorm = cosmo.sigma8**2 / sigmasqr(cosmo, 8.0, transfer_fn, **kwargs)

    if k.ndim == 1:
        pk = np.outer(primordial_matter_power(cosmo, k) * t**2, g**2)
    else:
        pk = primordial_matter_power(cosmo, k) * t**2 * g**2

    # Apply normalisation.
    pk = pk * pknorm

    return pk.squeeze()
def _calc_vars(theta: np.ndarray):
    """
    Calculate the mean and variance of the posterior.

    mu = theta(1, 5) * mask(5, 2) = (1, 2) vector
    s_1, s_2 = theta(n, 5) * mask(5, 2) * theta(5, n) = (n,) scalar
    rho = theta(1, 5) * mask(5, 1) = (1, 1) vector

    Sigma = [[s_1 ** 2,        rho * s_1 * s_2],
             [rho * s_1 * s_2, s_2 ** 2       ]]
    Sigma = outer([s_1, s_2], [s_1, s_2]) * elementwise_mult rho
    """
    # Have to do it this way to allow vectorization.
    s_vec = np.square(theta[2:4])
    rho = np.tanh(theta[4])  # last of the 5 parameters; theta[5] was out of range
    rho_matrix = np.eye(2) + rho * np.eye(2)[::-1]  # off-diagonal rho
    mu = theta[:2]
    Sigma = np.outer(s_vec, s_vec) * rho_matrix
    return mu, Sigma
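# Usage sketch for `_calc_vars` (illustrative values): theta packs
# (mu_1, mu_2, raw_s_1, raw_s_2, raw_rho) in that order.
theta = np.array([0.0, 1.0, 0.7, 0.5, 0.1])
mu, Sigma = _calc_vars(theta)
assert Sigma.shape == (2, 2)
assert np.allclose(Sigma, Sigma.T)  # a valid covariance is symmetric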
def update(state: WelfordAlgorithmState,
           value: np.DeviceArray) -> WelfordAlgorithmState:
    """Update the M2 matrix using the new value.

    Parameters
    ----------
    state:
        The current state of the Welford algorithm.
    value: jax.numpy.DeviceArray, shape (1,)
        The new sample (typically the position of the chain) used to
        update m2.
    """
    mean, m2, sample_size = state
    sample_size = sample_size + 1

    delta = value - mean
    mean = mean + delta / sample_size
    updated_delta = value - mean
    if is_diagonal_matrix:
        new_m2 = m2 + delta * updated_delta
    else:
        new_m2 = m2 + np.outer(updated_delta, delta)

    return WelfordAlgorithmState(mean, new_m2, sample_size)
def positional_encoding(seq_len, embed_dim, timescale=10000):
    """
    Returns positional encoding values.
    Assumes seq dimensions are (batch, seq_len, word_space).

    Output shape: seq_len x embed_dim
    """
    if embed_dim % 2 != 0:
        raise ValueError("Embedding dimension must be even")
    positions = jnp.arange(seq_len)
    i = jnp.arange(embed_dim // 2)
    angular_frequencies = 1 / jnp.power(timescale, 2 * i / embed_dim)
    angles = jnp.outer(positions, angular_frequencies)
    cosine = jnp.cos(angles)  # seq_len, embed_dim // 2
    sine = jnp.sin(angles)    # seq_len, embed_dim // 2
    pos_enc = jnp.concatenate([cosine, sine], axis=1)
    return pos_enc
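# Usage sketch for `positional_encoding`: a (seq_len, embed_dim) table
# that can be added to token embeddings; all values lie in [-1, 1].
pos_enc = positional_encoding(seq_len=50, embed_dim=16)
assert pos_enc.shape == (50, 16)
assert jnp.all(jnp.abs(pos_enc) <= 1.0)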
def update(state: WelfordAlgorithmState,
           value: np.DeviceArray) -> WelfordAlgorithmState:
    """Update the M2 matrix using the new value.

    Arguments:
    ----------
    state:
        The current state of the Welford algorithm.
    value: jax.numpy.DeviceArray, shape (1,)
        The new sample used to update m2.
    """
    mean, m2, count = state
    count = count + 1

    delta = value - mean
    mean = mean + delta / count
    updated_delta = value - mean
    if is_diagonal_matrix:
        m2 = m2 + delta * updated_delta
    else:
        m2 = m2 + np.outer(delta, updated_delta)

    return WelfordAlgorithmState(mean, m2, count)
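# Usage sketch for the Welford update above (illustrative; assumes
# `is_diagonal_matrix = False` in the enclosing scope and that
# WelfordAlgorithmState holds (mean, m2, count)). After n updates,
# m2 / (n - 1) is the running estimate of the sample covariance.
samples = np.array([[1.0, 2.0], [2.0, 0.0], [0.0, 1.0], [3.0, 1.0]])
state = WelfordAlgorithmState(np.zeros(2), np.zeros((2, 2)), 0)
for s in samples:
    state = update(state, s)
mean, m2, count = state
cov_estimate = m2 / (count - 1)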
def house_rightmult(A, v, beta):
    """
    Given the m x n matrix A and the length-n vector v with normalization
    beta such that P = I - beta v otimes dag(v) is the Householder matrix
    that reflects about v, compute AP.

    Parameters
    ----------
    A: array_like, shape (M, N)
        Matrix to be multiplied by P.
    v: array_like, shape (N,)
        Householder vector.
    beta: float
        Householder normalization.

    Returns
    -------
    C = AP
    """
    C = A - jnp.outer(A @ v, beta * dag(v))
    return C
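# Quick check for `house_rightmult` (illustrative; `dag` is assumed to be
# the conjugate transpose, as stated in the docstring): AP should equal
# A @ (I - beta * v otimes dag(v)).
import jax.numpy as jnp

dag = lambda v: v.conj().T  # assumption: conjugate-transpose helper
A = jnp.arange(12.0).reshape(3, 4)
v = jnp.array([1.0, 0.0, 1.0, 0.0])
beta = 2.0 / jnp.vdot(v, v).real
P = jnp.eye(4) - beta * jnp.outer(v, dag(v))
assert jnp.allclose(house_rightmult(A, v, beta), A @ P, atol=1e-5)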
def mkcovdiag_ASD(len_sc, rho, nxcirc, wvec=None, wwnrm=None):
    """Eigenvalues of ASD covariance (as diagonalized in the Fourier domain).

    [cdiag, dcdiag, ddcdiag] = mkcovdiag_ASD(rho, l, nxcirc, wvecsq)

    Compute the discrete ASD (RBF kernel) eigenspectrum using frequencies
    in [0, nxcirc]. See mkCov_ASD_factored for more info.

    INPUT (all python 1d arrays):
        len_sc - length scale of ASD kernel (determines smoothness)
        rho    - maximal prior variance ("overall scale")
        nxcirc - number of coefficients to consider for circular boundary
        wvec   - vector of freq for DFT
        wwnrm  - vector of freq for DFT (normalized)

    OUTPUT:
        cdiag [nxcirc x 1] - vector of eigenvalues of C for frequencies in w

    Note: nxcirc = nx corresponds to having a circular boundary.
    """
    # Compute diagonal of ASD covariance matrix.
    if wvec is not None:
        wvecsq = np.square(wvec)
        const = np.square(2 * np.pi / nxcirc)  # constant
        ww = wvecsq * const  # effective frequency vector
    elif wwnrm is not None:
        ww = wwnrm
    else:
        raise ValueError(
            "please provide either wvec or a normalized wvec (wwnrm)")

    cdiag = np.squeeze(
        np.sqrt(2 * np.pi) * rho * len_sc
        * np.exp(-.5 * np.outer(ww, np.square(len_sc))))
    return cdiag
def searchphase(y, x):
    y_t = jnp.outer(y, TPS)
    x_t = jnp.tile(x[:, None], (1, testing_phases))
    snr_t = 10. * jnp.log10(getpower(y_t) / getpower(y_t - x_t))
    return TPS[jnp.argmax(snr_t)]
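# `searchphase` closes over TPS, testing_phases, and getpower, none of which
# are shown above; the definitions below are illustrative assumptions only:
# a grid of unit-magnitude test phasors and a per-phase mean-power helper.
import jax.numpy as jnp

testing_phases = 64
TPS = jnp.exp(1j * jnp.linspace(0.0, 2.0 * jnp.pi, testing_phases,
                                endpoint=False))

def getpower(z):
    # Mean power over the time axis, one value per test phase.
    return jnp.mean(jnp.abs(z)**2, axis=0)

# Recover a known phase offset applied to a reference signal.
x = jnp.exp(1j * jnp.linspace(0.0, 10.0, 256))  # reference signal
y = x * jnp.exp(-1j * 0.3)                      # received copy, rotated
phase_hat = searchphase(y, x)                   # close to exp(1j * 0.3)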
def _gyration_tensor(positions):
    # Average the outer product of each position vector with itself.
    n = positions.shape[0]
    S = np.zeros((3, 3))
    for r in positions:
        S += np.outer(r, r)
    return S / n
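# Equivalent vectorized form (a sketch; same result as the loop above):
# the accumulated sum of outer products is a single einsum over the
# particle axis.
def _gyration_tensor_vectorized(positions):
    n = positions.shape[0]
    return np.einsum('ni,nj->ij', positions, positions) / n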
def outer(self, tensor_in_1, tensor_in_2):
    return jnp.outer(tensor_in_1, tensor_in_2)
def affine_transform(dist_params, scale, shift, value_transform=None):
    """Implements the "Categorical Algorithm" from https://arxiv.org/abs/1707.06887."""
    # Check inputs.
    chex.assert_rank([dist_params['logits'], scale, shift], [2, {0, 1}, {0, 1}])

    p = jax.nn.softmax(dist_params['logits'])
    batch_size = p.shape[0]

    if isscalar(scale):
        scale = jnp.full(shape=(batch_size,), fill_value=jnp.squeeze(scale))
    if isscalar(shift):
        shift = jnp.full(shape=(batch_size,), fill_value=jnp.squeeze(shift))

    chex.assert_shape(p, (batch_size, self.num_bins))
    chex.assert_shape([scale, shift], (batch_size,))

    if value_transform is None:
        f = f_inv = lambda x: x
    else:
        f, f_inv = value_transform

    # Variable names correspond to those defined in:
    # https://arxiv.org/abs/1707.06887
    z = self.__atoms
    Vmin, Vmax, Δz = z[0], z[-1], z[1] - z[0]
    Tz = f(jax.vmap(jnp.add)(jnp.outer(scale, f_inv(z)), shift))
    Tz = jnp.clip(Tz, Vmin, Vmax)  # keep values in valid range
    chex.assert_shape(Tz, (batch_size, self.num_bins))

    b = (Tz - Vmin) / Δz  # float in [0, num_bins - 1]
    l = jnp.floor(b).astype('int32')  # noqa: E741  # int in {0, 1, ..., num_bins - 1}
    u = jnp.ceil(b).astype('int32')   # int in {0, 1, ..., num_bins - 1}
    chex.assert_shape([p, b, l, u], (batch_size, self.num_bins))

    m = jnp.zeros_like(p)
    i = jnp.expand_dims(jnp.arange(batch_size), axis=1)  # batch index
    m = jax.ops.index_add(m, (i, l), p * (u - b), indices_are_sorted=True)
    m = jax.ops.index_add(m, (i, u), p * (b - l), indices_are_sorted=True)
    m = jax.ops.index_add(m, (i, l), p * (l == u), indices_are_sorted=True)
    # chex.assert_tree_all_close(jnp.sum(m, axis=1), jnp.ones(batch_size), rtol=1e-6)

    # The above index trickery is equivalent to:
    #
    #   m_alt = onp.zeros((batch_size, self.num_bins))
    #   for i in range(batch_size):
    #       for j in range(self.num_bins):
    #           if l[i, j] == u[i, j]:
    #               m_alt[i, l[i, j]] += p[i, j]  # don't split if b[i, j] is an integer
    #           else:
    #               m_alt[i, l[i, j]] += p[i, j] * (u[i, j] - b[i, j])
    #               m_alt[i, u[i, j]] += p[i, j] * (b[i, j] - l[i, j])
    #   chex.assert_tree_all_close(m, m_alt, rtol=1e-6)

    return {'logits': jnp.log(jnp.maximum(m, 1e-16))}
optimizers = [
    optim_rcg,
    optim_rsd,
    # optim_rlbfgs
]

RNG, key = random.split(RNG)
t_cov, t_mu = orig_man.rand(key)
RNG, key = random.split(RNG)
data = random.multivariate_normal(key, mean=t_mu, cov=t_cov, shape=(N,))

s_mu = jnp.mean(data, axis=0)
s_cov = jnp.dot((data - s_mu).T, data - s_mu) / N
MLE_rep = jnp.append(
    jnp.append(s_cov + jnp.outer(s_mu, s_mu), jnp.array([s_mu]), axis=0),
    jnp.array([jnp.append(s_mu, 1)]).T,
    axis=1)
if chol:
    MLE_chol = jnp.linalg.cholesky(MLE_rep)
    MLE_chol = MLE_chol.T[~(MLE_chol.T == 0.)].ravel()

def nloglik(X):
    y = jnp.concatenate([data.T, jnp.ones(shape=(1, N))], axis=0)
    datapart = jnp.trace(jnp.linalg.solve(X, jnp.matmul(y, y.T)))
    return 0.5 * (N * jnp.linalg.slogdet(X)[1] + datapart)

if chol:
def compute_OBC_energy_vectorized(distance_matrix, radii, scales, charges,
                                  offset=0.009, screening=138.935484,
                                  surface_tension=28.3919551,
                                  solvent_dielectric=78.5,
                                  solute_dielectric=1.0):
    """Compute GBSA-OBC energy from a distance matrix."""
    N = len(radii)
    eye = np.eye(N, dtype=distance_matrix.dtype)
    r = distance_matrix + eye  # so I don't have divide-by-zero nonsense
    or1 = radii.reshape((N, 1)) - offset
    or2 = radii.reshape((1, N)) - offset
    sr2 = scales.reshape((1, N)) * or2

    L = np.maximum(or1, abs(r - sr2))
    U = r + sr2
    I = step(r + sr2 - or1) * 0.5 * (
        1 / L - 1 / U + 0.25 * (r - sr2**2 / r) * (1 / (U**2) - 1 / (L**2))
        + 0.5 * np.log(L / U) / r)
    I -= np.diag(np.diag(I))
    I = np.sum(I, axis=1)

    # Next, compute Born radii.
    offset_radius = radii - offset
    psi = I * offset_radius
    psi_coefficient = 0.8
    psi2_coefficient = 0
    psi3_coefficient = 2.909125
    psi_term = ((psi_coefficient * psi) + (psi2_coefficient * psi**2)
                + (psi3_coefficient * psi**3))
    B = 1 / (1 / offset_radius - np.tanh(psi_term) / radii)

    # Finally, compute the three energy terms.
    E = 0.0

    # Single particle.
    E += np.sum(surface_tension * (radii + 0.14)**2 * (radii / B)**6)
    E += np.sum(-0.5 * screening
                * (1 / solute_dielectric - 1 / solvent_dielectric)
                * charges**2 / B)

    # Particle pair.
    f = np.sqrt(r**2 + np.outer(B, B) * np.exp(-r**2 / (4 * np.outer(B, B))))
    charge_products = np.outer(charges, charges)
    E += np.sum(np.triu(
        -screening * (1 / solute_dielectric - 1 / solvent_dielectric)
        * charge_products / f, k=1))

    return E
def _matmul_impl(state: Carry, data: MatmulData) -> Tuple[Carry, Array]:
    (Fp, Vp, Yp) = state
    Vn, Pn, Yn = data
    Fn = _pdot(Pn, Fp + jnp.outer(Vp, Yp))
    return (Fn, Vn, Yn), Fn
def _solve_impl(state: Carry, data: Data) -> Tuple[Carry, Array]:
    Fp, Wp, Zp = state
    Un, Wn, Pn, Yn = data
    Fn = _pdot(Pn, Fp + jnp.outer(Wp, Zp))
    Zn = Yn - Un @ Fn
    return (Fn, Wn, Zn), Zn
def __init__(self, n_samp, seed, outdtype, pardtype, holomorphic):
    self.dtype = outdtype

    self.target = {
        "a": jnp.array([[[0j], [0j]]], dtype=jnp.complex128),
        "b": jnp.array(0, dtype=jnp.float64),
        "c": jnp.array(0j, dtype=jnp.complex64),
    }

    if pardtype is None:  # mixed precision as above
        pass
    else:
        self.target = jax.tree_map(
            lambda x: astype_unsafe(x, pardtype),
            self.target,
        )

    k = jax.random.PRNGKey(seed)
    k1, k2, k3, k4, k5 = jax.random.split(k, 5)

    self.samples = jax.random.normal(k1, (n_samp, 2))
    self.w = jax.random.normal(k2, (n_samp,), self.dtype).astype(
        self.dtype
    )  # TODO remove astype once it's fixed in jax
    self.params = tree_random_normal_like(k3, self.target)
    self.v = tree_random_normal_like(k4, self.target)
    self.grad = tree_random_normal_like(k5, self.target)

    if holomorphic:

        @partial(jax.vmap, in_axes=(None, 0))
        def f(params, x):
            return astype_unsafe(
                params["a"][0][0][0] * x[0]
                + params["b"] * x[1]
                + params["c"] * (x[0] * x[1])
                + jnp.sin(x[1] * params["a"][0][1][0])
                * jnp.cos(x[0] * params["b"] + 1j)
                * params["c"],
                self.dtype,
            )

    else:

        @partial(jax.vmap, in_axes=(None, 0))
        def f(params, x):
            return astype_unsafe(
                params["a"][0][0][0].conjugate() * x[0]
                + params["b"] * x[1]
                + params["c"] * (x[0] * x[1])
                + jnp.sin(x[1] * params["a"][0][1][0])
                * jnp.cos(x[0] * params["b"].conjugate() + 1j)
                * params["c"].conjugate(),
                self.dtype,
            )

    self.f = f

    self.params_real_flat = tree_toreal_flat(self.params)
    self.grad_real_flat = tree_toreal_flat(self.grad)
    self.v_real_flat = tree_toreal_flat(self.v)
    self.ok_real = self.grads_real(self.params_real_flat, self.samples)
    self.okmean_real = self.ok_real.mean(axis=0)
    self.dok_real = self.ok_real - self.okmean_real
    self.S_real = (
        self.dok_real.conjugate().transpose() @ self.dok_real / n_samp
    ).real
    self.scale = jnp.sqrt(self.S_real.diagonal())
    self.S_real_scaled = self.S_real / jnp.outer(self.scale, self.scale)
def normal_expected_stats(normal):
    mu, sigma = normal_nat_to_std(*normal)
    t1 = np.outer(mu, mu) + sigma
    t2 = mu
    return t1, t2
def niw_std_to_nat(mu, kappa, psi, nu):
    n1 = kappa * np.outer(mu, mu) + psi
    n2 = kappa * mu
    n3 = kappa
    n4 = nu + psi.shape[0] + 2
    return n1, n2, n3, n4