def squareform(distances):
    """
    IN:  output from `pairwise_distances`, an array of length
         l = (n^2 - n) / 2 with entries
         d(x_1, x_2), d(x_1, x_3), ..., d(x_{n-1}, x_n).
    OUT: a symmetric n x n distance matrix with entries d(x_i, x_j).
    """
    l = distances.shape[0]
    n = getn(l)
    out = jnp.zeros((n, n))
    # `jax.ops.index_update` was removed from JAX; use the `.at` syntax.
    out = out.at[jnp.triu_indices(n, k=1)].set(distances)
    out = out + out.T
    return out
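# A minimal usage sketch. `getn` is the helper referenced above but not shown
# in this snippet; a plausible definition (an assumption here) inverts
# l = (n^2 - n) / 2:
import math
import jax.numpy as jnp

getn = lambda l: int((1 + math.sqrt(1 + 8 * l)) // 2)  # assumed helper

distances = jnp.array([1.0, 2.0, 3.0])  # d(x1,x2), d(x1,x3), d(x2,x3)
D = squareform(distances)
# D == [[0., 1., 2.],
#       [1., 0., 3.],
#       [2., 3., 0.]]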
def get_nuclear_interaction_energy(locations, nuclear_charges, interaction_fn):
  """Gets nuclear interaction energy for an atomic chain.

  Args:
    locations: Float numpy array with shape (num_nuclei,), the locations of
        the nuclei.
    nuclear_charges: Float numpy array with shape (num_nuclei,), the charges
        of the nuclei.
    interaction_fn: A function that takes displacements and returns a float
        numpy array with the same shape as displacements.

  Returns:
    Float.

  Raises:
    ValueError: If locations.ndim or nuclear_charges.ndim is not 1.
  """
  if locations.ndim != 1:
    raise ValueError(
        'locations.ndim is expected to be 1 but got %d' % locations.ndim)
  if nuclear_charges.ndim != 1:
    raise ValueError(
        'nuclear_charges.ndim is expected to be 1 but got %d'
        % nuclear_charges.ndim)
  # Convert locations and nuclear_charges to jax.numpy arrays.
  locations = jnp.array(locations)
  nuclear_charges = jnp.array(nuclear_charges)
  # Charge products and displacements for every unordered pair of nuclei.
  indices_0, indices_1 = jnp.triu_indices(locations.size, k=1)
  charges_products = nuclear_charges[indices_0] * nuclear_charges[indices_1]
  return jnp.sum(
      charges_products
      * interaction_fn(locations[indices_0] - locations[indices_1]))
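# Usage sketch with a hypothetical soft-Coulomb interaction; the actual
# interaction_fn passed by callers is not part of this snippet.
import numpy as np
import jax.numpy as jnp

locations = np.array([-1.0, 0.0, 1.0])
nuclear_charges = np.array([1.0, 2.0, 1.0])
soft_coulomb = lambda d: 1.0 / jnp.sqrt(d ** 2 + 1.0)
energy = get_nuclear_interaction_energy(locations, nuclear_charges,
                                        soft_coulomb)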
def nloglik_chol(X):
    # `p`, `N`, and `data` come from the enclosing scope. Fill the upper
    # triangle with the parameter vector, then transpose to get a
    # lower-triangular Cholesky factor of the (p+1) x (p+1) covariance.
    cov = jnp.zeros(shape=(p + 1, p + 1)).at[jnp.triu_indices(p + 1)].set(X).T
    # log|cov cov^T| = 2 * sum(log(diag(cov))).
    logdet = 2 * jnp.sum(jnp.log(jnp.diag(cov)))
    y = jnp.concatenate([data.T, jnp.ones(shape=(1, N))], axis=0)
    sol = jnp.linalg.solve(cov, y)
    return 0.5 * (N * logdet + jnp.einsum('ij,ij', sol, sol))
def Epot_lj(positions, L: float, M: int):
    """Potential energy for the Lennard-Jones potential in reduced units.

    In this system of units, epsilon=1 and sigma=2**(-1. / 6.).
    The function accepts numpy arrays of shape (M, 3) [2D] or (M*3,) [1D]."""
    if (positions.ndim != 2 or positions.shape[1] != 3
            ) and not (positions.ndim == 1 and positions.size == M * 3):
        raise ValueError(
            "positions must be an Mx3 array or a 1D array that can be "
            "reshaped to Mx3!")
    if positions.ndim == 1 and positions.size == M * 3:
        positions = positions.reshape((M, 3))  # Reshape to Mx3
    #sig = 1 / np.power(2, 1 / 6)
    sig = 1.
    # Compute all squared distances between pairs under the minimum image
    # convention for a periodic box of side L.
    delta = positions[:, np.newaxis, :] - positions
    delta = delta - L * np.around(delta / L, decimals=0)
    r2 = (delta * delta).sum(axis=2)  # r^2 ...squared distances
    # Take only the upper triangle (combinations of two atoms).
    indices = np.triu_indices(r2.shape[0], k=1)
    rm2 = sig * sig / r2[indices]  # (sig/r)^2
    # Compute the potential energy, recycling as many calculations as possible.
    rm6 = rm2 * rm2 * rm2  # (sig/r)^6
    rm12 = rm6 * rm6  # (sig/r)^12
    return (rm12 - 2. * rm6).sum()
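# Usage sketch: two particles at the potential minimum (r = 1 in these
# reduced units) inside a periodic box of side L = 10; the values are chosen
# purely for illustration.
import numpy as np

positions = np.array([[0.0, 0.0, 0.0],
                      [1.0, 0.0, 0.0]])
print(Epot_lj(positions, L=10.0, M=2))  # -1.0, the depth of the well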
def dot_interact(concat_features):
  """Performs feature interaction operation between dense or sparse features.

  Input tensors represent dense or sparse features.
  Pre-condition: The tensors have been stacked along dimension 1.

  Args:
    concat_features: Array of features with shape [B, n_features, feature_dim].

  Returns:
    activations: Array representing interacted features.
  """
  batch_size = concat_features.shape[0]

  # Interact features, select the upper-triangular portion, and reshape.
  xactions = jnp.matmul(concat_features,
                        jnp.transpose(concat_features, [0, 2, 1]))
  feature_dim = xactions.shape[-1]

  indices = jnp.array(jnp.triu_indices(feature_dim))
  num_elems = indices.shape[1]
  indices = jnp.tile(indices, [1, batch_size])
  indices0 = jnp.reshape(
      jnp.tile(jnp.reshape(jnp.arange(batch_size), [-1, 1]), [1, num_elems]),
      [1, -1])
  indices = tuple(jnp.concatenate((indices0, indices), 0))
  activations = xactions[indices]
  activations = jnp.reshape(activations, [batch_size, -1])
  return activations
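# Usage sketch: a batch of 2 examples with 3 stacked features of dimension 4.
# Each example yields the upper triangle (including the diagonal) of its
# 3 x 3 feature Gram matrix, i.e. 6 interaction values.
import jax.numpy as jnp

concat_features = jnp.ones((2, 3, 4))
print(dot_interact(concat_features).shape)  # (2, 6)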
def _gen_recurrence_mask(
    l_max: int, is_normalized: bool = True) -> Tuple[jnp.ndarray, jnp.ndarray]:
  """Generates mask for recurrence relation on the remaining entries.

  The remaining entries are with respect to the diagonal and offdiagonal
  entries.

  Args:
    l_max: see `gen_normalized_legendre`.
    is_normalized: True if the recurrence mask is used by normalized
      associated Legendre functions.

  Returns:
    Arrays representing the mask used by the recurrence relations.
  """
  # Computes all coefficients.
  m_mat, l_mat = jnp.mgrid[:l_max + 1, :l_max + 1]
  if is_normalized:
    c0 = l_mat * l_mat
    c1 = m_mat * m_mat
    c2 = 2.0 * l_mat
    c3 = (l_mat - 1.0) * (l_mat - 1.0)
    d0 = jnp.sqrt((4.0 * c0 - 1.0) / (c0 - c1))
    d1 = jnp.sqrt(((c2 + 1.0) * (c3 - c1)) / ((c2 - 3.0) * (c0 - c1)))
  else:
    d0 = (2.0 * l_mat - 1.0) / (l_mat - m_mat)
    d1 = (l_mat + m_mat - 1.0) / (l_mat - m_mat)

  d0_mask_indices = jnp.triu_indices(l_max + 1, 1)
  d1_mask_indices = jnp.triu_indices(l_max + 1, 2)
  d_zeros = jnp.zeros((l_max + 1, l_max + 1))
  d0_mask = d_zeros.at[d0_mask_indices].set(d0[d0_mask_indices])
  d1_mask = d_zeros.at[d1_mask_indices].set(d1[d1_mask_indices])

  # Creates a 3D mask that contains 1s on the diagonal plane and 0s elsewhere.
  # i = jnp.arange(l_max + 1)[:, None, None]
  # j = jnp.arange(l_max + 1)[None, :, None]
  # k = jnp.arange(l_max + 1)[None, None, :]
  i, j, k = jnp.ogrid[:l_max + 1, :l_max + 1, :l_max + 1]
  mask = 1.0 * (i + j - k == 0)

  d0_mask_3d = jnp.einsum('jk,ijk->ijk', d0_mask, mask)
  d1_mask_3d = jnp.einsum('jk,ijk->ijk', d1_mask, mask)
  return (d0_mask_3d, d1_mask_3d)
def triu_matrix_from_v(x, ndim):
    assert x.shape[-1] == (ndim * (ndim + 1)) // 2
    matrix = jnp.zeros(x.shape[:-1] + (ndim, ndim))
    idx = jnp.triu_indices(ndim)
    index_update = lambda x, idx, y: x.at[idx].set(y)
    # vmap the scatter over every leading (batch) dimension of x.
    for _ in range(x.ndim - 1):
        index_update = jax.vmap(index_update, in_axes=(0, None, 0))
    return index_update(matrix, idx, x)
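# Usage sketch: pack 3 values into the upper triangle of a 2 x 2 matrix,
# batched over one leading axis.
import jax.numpy as jnp

v = jnp.arange(6.0).reshape(2, 3)  # batch of 2, ndim*(ndim+1)/2 = 3 values
m = triu_matrix_from_v(v, ndim=2)
# m[0] == [[0., 1.],
#          [0., 2.]]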
def jax_vech(X):
    '''
    Half-vectorization operator; returns the n*(n+1)/2 vector obtained by
    stacking the columns of the unique (lower-triangular) items of a
    symmetric n x n matrix.
    '''
    rix, cix = jnp.triu_indices(len(X))
    res = jnp.take(X.T, rix * len(X) + cix)
    return res
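# Usage sketch: vech of a symmetric 2 x 2 matrix stacks the lower triangle
# column by column.
import jax.numpy as jnp

S = jnp.array([[1.0, 2.0],
               [2.0, 3.0]])
print(jax_vech(S))  # [1. 2. 3.]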
def minimum_distance(x, sdim):
    n_particles = x.shape[0] // sdim
    x = x.reshape(-1, sdim)
    distances = (-x[jnp.newaxis, :, :] +
                 x[:, jnp.newaxis, :])[jnp.triu_indices(n_particles, 1)]
    return jnp.linalg.norm(distances, axis=1)
def chol_sample(key, d):
    idx_u = jnp.triu_indices(d)
    idx_d = jnp.diag_indices(d)
    L = random.normal(key, (d, d), dtype=jnp.float64)
    # `jax.ops.index_update` was removed from JAX; use the `.at` syntax.
    L = L.at[idx_u].set(0.0)
    L = L.at[idx_d].set(random.normal(key, (d,)) ** 2)
    return L
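# Usage sketch: a random lower-triangular factor with a non-negative diagonal
# (float64 assumes jax.config.update("jax_enable_x64", True) has been set).
from jax import random

key = random.PRNGKey(0)
L = chol_sample(key, d=3)  # zeros above the diagonal, squared normals on it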
def jax_invech(v):
    '''
    Inverse half-vectorization operator; rebuilds the symmetric matrix
    whose vech is v.
    '''
    rows = int(jnp.round(.5 * (-1 + jnp.sqrt(1 + 8 * len(v)))))
    res = jnp.zeros((rows, rows))
    res = res.at[jnp.triu_indices(rows)].set(v)
    res = res + res.T - jnp.diag(jnp.diag(res))
    return res
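# Usage sketch: a round trip through jax_vech / jax_invech recovers the
# symmetric matrix.
import jax.numpy as jnp

S = jnp.array([[1.0, 2.0],
               [2.0, 3.0]])
assert jnp.allclose(jax_invech(jax_vech(S)), S)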
def distance(self, x, sdim, L):
    n_particles = x.shape[0] // sdim
    x = x.reshape(-1, sdim)
    dis = -x[jnp.newaxis, :, :] + x[:, jnp.newaxis, :]
    dis = dis[jnp.triu_indices(n_particles, 1)]
    dis = L[jnp.newaxis, :] / 2.0 * jnp.sin(jnp.pi * dis / L[jnp.newaxis, :])
    return dis
def get_all_pairs_indices(n: int) -> Tuple[Array, Array]:
    """all indices i, j such that i < j < n"""
    n_interactions = n * (n - 1) // 2
    inds_i, inds_j = np.triu_indices(n, k=1)
    assert len(inds_i) == n_interactions
    return inds_i, inds_j
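# Usage sketch: all unordered pairs among n = 4 particles.
inds_i, inds_j = get_all_pairs_indices(4)
# inds_i == [0, 0, 0, 1, 1, 2]
# inds_j == [1, 2, 3, 2, 3, 3]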
def minimum_distance(x, sdim):
    """Computes distances between particles using the minimum image
    convention; the box size L comes from the enclosing scope."""
    n_particles = x.shape[0] // sdim
    x = x.reshape(-1, sdim)
    distances = (-x[jnp.newaxis, :, :] +
                 x[:, jnp.newaxis, :])[jnp.triu_indices(n_particles, 1)]
    distances = jnp.remainder(distances + L / 2.0, L) - L / 2.0
    return jnp.linalg.norm(distances, axis=1)
def ll_chol(pars, y):
    p = y.shape[-1]
    X, theta = pars[:-p], pars[-p:]
    # Lower-triangular Cholesky factor from the packed parameter vector.
    sigma = jnp.zeros(shape=(p, p)).at[jnp.triu_indices(p)].set(X).T
    sigma = jnp.matmul(sigma, sigma.T)
    sc = jnp.sqrt(jnp.diag(sigma))
    al = jnp.einsum('i,i->i', 1 / sc, theta)
    capital_phi = jnp.sum(norm.logcdf(jnp.matmul(al, y.T)))
    small_phi = jnp.sum(mvn.logpdf(y, mean=jnp.zeros(p), cov=sigma))
    # N * log(2) from the skew-normal density 2 * phi(y; sigma) * Phi(al'y).
    return -(y.shape[0] * jnp.log(2.0) + small_phi + capital_phi)
def call(self,
         inputs: Mapping[str, jnp.ndarray],
         rng: jnp.ndarray = None,
         sample: Optional[bool] = False,
         **kwargs
         ) -> Mapping[str, jnp.ndarray]:
  outputs = {}
  dim, dtype = inputs["x"].shape[-1], inputs["x"].dtype

  L = hk.get_parameter("L", shape=(dim, dim), dtype=dtype,
                       init=hk.initializers.RandomNormal(0.01))
  U = hk.get_parameter("U", shape=(dim, dim), dtype=dtype,
                       init=hk.initializers.RandomNormal(0.01))
  log_d = hk.get_parameter("log_d", shape=(dim,), dtype=dtype, init=jnp.zeros)

  lower_mask = jnp.ones((dim, dim), dtype=bool)
  # `jax.ops.index_update` was removed from JAX; use the `.at` syntax.
  lower_mask = lower_mask.at[jnp.triu_indices(dim)].set(False)

  if self.safe_diag:
    d = util.proximal_relu(log_d) + 1e-5
    log_d = jnp.log(d)

  def b_init(shape, dtype):
    x = inputs["x"]
    if x.ndim == 1:
      return jnp.zeros(shape, dtype=dtype)

    # Initialize to the batch mean.
    z = jnp.dot(x, (U*lower_mask.T).T) + x
    z *= jnp.exp(log_d)
    z = jnp.dot(z, (L*lower_mask).T) + z
    b = -jnp.mean(z, axis=0)
    return b

  b = hk.get_parameter("b", shape=(dim,), dtype=dtype, init=b_init)

  # It's way faster to allocate a full matrix for L and U and then mask than
  # it is to allocate only the lower/upper parts and then reshape.
  if sample == False:
    x = inputs["x"]
    z = jnp.dot(x, (U*lower_mask.T).T) + x
    z *= jnp.exp(log_d)
    z = jnp.dot(z, (L*lower_mask).T) + z
    outputs["x"] = z + b
  else:
    z = inputs["x"]

    @self.auto_batch
    def invert(z):
      x = L_solve(L, z - b)
      x = x*jnp.exp(-log_d)
      return U_solve(U, x)

    outputs["x"] = invert(z)

  outputs["log_det"] = jnp.sum(log_d, axis=-1)*jnp.ones(self.batch_shape)
  return outputs
def f_bond_length(x):
    # reshape to (n_atoms, 3)
    x = jnp.reshape(x, (self.n_atoms, 3))
    # compute all pairwise differences
    z = x[:, None] - x[None, :]
    # select the upper triangle (lexicographic order)
    i0 = jnp.triu_indices(self.n_atoms, 1)
    diff = z[i0]
    # compute the bond lengths
    r = jnp.linalg.norm(diff, axis=1)
    return r
def registration_individuals(x, y, aar_names, max_iter=10000, aars=None):
    if aars is None:
        aars = range(len(aar_names))

    aar_indices = [y == aar for aar in aars]
    uti_indices = [np.triu_indices(sum(y == aar), k=1) for aar in aars]

    def cost_function(x, y):
        # Sum of all pairwise distances within each AAR.
        def foo(x, uti):
            dr = (x[:, uti[0]] - x[:, uti[1]])
            return np.sqrt(np.sum(dr*dr, axis=0)).sum()
        return sum([foo(x[:, aar_indices[aar]], uti_indices[aar])
                    for aar in range(len(aars))])

    def transform(param, x):
        # Per-individual 2D rotation (theta) plus translation (delta_p).
        thetas = param[0:len(x)]
        delta_ps = np.reshape(param[len(x):], (2, len(x)))
        return np.hstack([
            np.dot(np.array([[np.cos(theta), -np.sin(theta)],
                             [np.sin(theta),  np.cos(theta)]]), x_s)
            + np.expand_dims(delta_p, 1)
            for theta, delta_p, x_s in zip(thetas, delta_ps.T, x)])

    def func(param, x, y):
        value = cost_function(transform(param, x), y)
        return value

    loss = lambda param: func(param, x, y)

    opt_init, opt_update, get_params = optimizers.adagrad(step_size=1,
                                                          momentum=0.9)

    @jit
    def step(i, opt_state):
        params = get_params(opt_state)
        g = grad(loss)(params)
        return opt_update(i, g, opt_state)

    net_params = numpy.hstack((
        numpy.random.uniform(-numpy.pi, numpy.pi, len(x)),
        numpy.zeros(2*len(x))))
    previous_value = loss(net_params)
    logging.info('Iteration 0: loss = %f' % (previous_value))
    opt_state = opt_init(net_params)
    for i in range(max_iter):
        opt_state = step(i, opt_state)
        if i > 0 and i % 10 == 0:
            net_params = get_params(opt_state)
            current_value = loss(net_params)
            logging.info('Iteration %d: loss = %f' % (i+1, current_value))
            if numpy.isclose(previous_value/current_value, 1):
                logging.info('Converged after %d iterations' % (i+1))
                net_params = get_params(opt_state)
                return transform(net_params, x)
            previous_value = current_value

    logging.warning('Not converged after %d iterations' % (i+1))
    net_params = get_params(opt_state)
    return transform(net_params, x)
def jax_l2_pdist(X):
  """Computes the pairwise distances between points in X.

  Args:
    X: A 2d numpy array, the points.

  Returns:
    dm: The pairwise distances between points in X, as a flattened
      upper-triangular matrix.
  """
  n = X.shape[0]
  diffs = (X[:, None] - X[None, :])[np.triu_indices(n=n, k=1)]
  return np.linalg.norm(diffs, axis=1)
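# Usage sketch: pairwise distances of three 1-D points, in the same order as
# np.triu_indices (here `np` is assumed to be jax.numpy, as in the function).
X = np.array([[0.0], [1.0], [3.0]])
print(jax_l2_pdist(X))  # [1. 3. 2.]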
def unpack_triu(x, n, hermi=0):
    R = np.zeros([n, n])
    idx = np.triu_indices(n)
    # `jax.ops.index_update`/`index_mul` were removed from JAX; use `.at`.
    R = R.at[idx].set(x)
    if hermi == 0:
        return R
    elif hermi == 1:
        R = R + R.conj().T
        R = R.at[np.diag_indices(n)].multiply(0.5)
        return R
    elif hermi == 2:
        return R - R.conj().T
    else:
        raise KeyError
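# Usage sketch: unpack 3 values into a symmetric 2 x 2 matrix; hermi=1 halves
# the doubled diagonal after forming R + R.conj().T.
x = np.array([1.0, 2.0, 3.0])
print(unpack_triu(x, 2, hermi=1))
# [[1. 2.]
#  [2. 3.]]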
def unflattened(self, x: Array, dimensions: int) -> Array:
    # `dimensions` is recomputed from the flattened size k = d*(d+1)/2.
    k = x.shape[-1]
    sqrt_discriminant = sqrt(1 + 8 * k)
    i_sqrt_discriminant = int(sqrt_discriminant)
    if i_sqrt_discriminant != sqrt_discriminant:
        raise ValueError(
            f"Invalid flattened size {k}: 1 + 8 * {k} has no integer square "
            f"root ({sqrt_discriminant}).")
    if i_sqrt_discriminant % 2 != 1:
        raise ValueError(
            f"Invalid flattened size {k}: sqrt(1 + 8 * {k}) must be odd.")
    dimensions = (i_sqrt_discriminant - 1) // 2
    index = (..., *jnp.triu_indices(dimensions))
    empty = jnp.empty(x.shape[:-1] + (dimensions, dimensions), dtype=x.dtype)
    # Transpose only the last two axes so leading batch axes are preserved.
    lower_diagonal = jnp.swapaxes(empty.at[index].set(x), -2, -1)
    if self.hermitian:
        lower_diagonal = lower_diagonal.conjugate()
    return lower_diagonal.at[index].set(x)
def run(manifold, p, k):
    k, key = random.split(k)
    tslant = random.normal(key, shape=(p,))
    k, key = random.split(k)
    tcov = random.normal(key, shape=(p, p))
    tcov = tcov @ tcov.T
    tmean = jnp.zeros(shape=(p,))

    sn = SkewNormal(loc=tmean, cov=tcov, sl=tslant)
    k, key = random.split(k)
    data = sn.sample(key, shape=(N,))

    # s_mu = jnp.mean(data, axis=0)
    # s_cov = jnp.dot((data - s_mu).T, data - s_mu) / N
    # MLE = jnp.append(jnp.append(s_cov + jnp.outer(s_mu, s_mu),
    #                             jnp.array([s_mu]), axis=0),
    #                  jnp.array([jnp.append(s_mu, 1)]).T, axis=1)
    # mle_chol = jnp.linalg.cholesky(MLE)
    # mle_chol = mle_chol.T[jnp.triu_indices_from(mle_chol)]
    # data = jnp.concatenate([data.T, jnp.ones(shape=(1, N))], axis=0).T

    fun = jit(lambda x, y: ll(x, y, data))
    # gra = jit(grad(fun))
    init = (jnp.identity(p), jnp.ones(shape=(p,)))
    # print(fun(init[0], init[1]))
    # ll_mle = fun(MLE)
    res_cg = optimization('rcg', manifold, fun=fun, init=init)
    res_bfgs = optimization('rlbfgs', manifold, fun=fun, init=init)

    fun = jit(lambda x, y: ll_chol(x, y, data))
    init = (jnp.identity(p)[jnp.triu_indices(p)], jnp.ones(shape=(p,)))
    # gra = jit(grad(fun))
    # ll_mle_chol = fun(mle_chol)
    res_cho = optimization('chol', fun=fun, init=init)
    return p, *res_cg, *res_bfgs, *res_cho
            # print('Maxiterations reached')
            break
        if jnp.isclose(f0, old_f0, rtol=tol):
            # print('Function not changing')
            break
        if (gr_sig_norm <= tol) and (gr_the_norm <= tol):
            # print('Reached mingradnorm')
            break
        old_f0 = f0
    toc = time()
    res.append([p, k, toc - tic, f0])

    tic = time()
    init_chol = jnp.append(
        jnp.identity(p)[jnp.triu_indices(p)], jnp.ones(shape=(p,)))
    fun_chol = jit(lambda x: ll_chol(x, data))
    gra_chol = jit(grad(fun_chol))
    res_chol = minimize(fun_chol, init_chol, method='cg', jac=gra_chol,
                        tol=tol)
    toc = time()
    res[-1] = res[-1] + [res_chol['nit'], toc - tic, res_chol['fun']]

df = pd.DataFrame(data=res, columns=[
    'p', 'riem_iter', 'riem_time', 'riem_fun',
    'chol_iter', 'chol_time', 'chol_fun'
])
    #     maxiter=maxiter, mingradnorm=tol,
    #     verbosity=0, logverbosity=logs)
    optimizers = [
        optim_rcg,
        optim_rsd,
        # optim_rlbfgs
    ]

    RNG, key = random.split(RNG)
    data, t_cov, t_mu = generate_data(key, p)
    MLE_rep = t_cov, t_mu
    if chol:
        MLE_chol = jnp.linalg.cholesky(t_cov)
        MLE_chol = jnp.append(MLE_chol.T[jnp.triu_indices(p)], t_mu)

    def nloglik(X):
        sigma = X[0]
        theta = X[1]
        return ll(sigma, theta, data)

    if chol:
        def nloglik_chol(X):
            return ll_chol(X, data)

        fun_chol = jit(nloglik_chol)
        gra_chol = jit(grad(fun_chol))
        true_fun_chol = fun_chol(MLE_chol)
def ll_chol(pars, y):
    p = y.shape[-1]
    X, theta = pars[:-p], pars[-p:]
    # Lower-triangular Cholesky factor from the packed parameter vector.
    sigma = jnp.zeros(shape=(p, p)).at[jnp.triu_indices(p)].set(X).T
    sigma = jnp.matmul(sigma, sigma.T)
    return ll(sigma, theta, y)
def _gen_derivatives(p: jnp.ndarray,
                     x: jnp.ndarray,
                     is_normalized: bool) -> jnp.ndarray:
  """Generates derivatives of associated Legendre functions of the first kind.

  Args:
    p: The 3D array containing the values of associated Legendre functions;
      the dimensions are in the sequence of order (m), degree (l), and
      evaluation points.
    x: A vector of type `float32` or `float64` containing the sampled points.
    is_normalized: True if the associated Legendre functions are normalized.

  Returns:
    The 3D array representing the derivatives of associated Legendre functions
    of the first kind.
  """

  num_m, num_l, num_x = p.shape

  # p_{l-1}^m.
  p_m_lm1 = jnp.pad(p, ((0, 0), (1, 0), (0, 0)))[:, :num_l, :]

  # p_{l-1}^{m+2}.
  p_mp2_lm1 = jnp.pad(p_m_lm1, ((0, 2), (0, 0), (0, 0)))[2:num_m + 2, :, :]

  # p_{l-1}^{m-2}.
  p_mm2_lm1 = jnp.pad(p_m_lm1, ((2, 0), (0, 0), (0, 0)))[:num_m, :, :]

  # Derivative computation requires negative orders.
  if is_normalized:
    raise NotImplementedError(
        'Negative orders for normalization is not implemented yet.')
  else:
    if num_l > 1:
      l_vec = jnp.arange(1, num_l - 1)
      p_p1 = p[1, 1:num_l - 1, :]
      coeff = -1.0 / ((l_vec + 1) * l_vec)
      update_p_p1 = jnp.einsum('i,ij->ij', coeff, p_p1)
      p_mm2_lm1 = p_mm2_lm1.at[1, 2:num_l, :].set(update_p_p1)

    if num_l > 2:
      l_vec = jnp.arange(2, num_l - 1)
      p_p2 = p[2, 2:num_l - 1, :]
      coeff = 1.0 / ((l_vec + 2) * (l_vec + 1) * l_vec)
      update_p_p2 = jnp.einsum('i,ij->ij', coeff, p_p2)
      p_mm2_lm1 = p_mm2_lm1.at[0, 3:num_l, :].set(update_p_p2)

  m_mat, l_mat = jnp.mgrid[:num_m, :num_l]

  coeff_zeros = jnp.zeros((num_m, num_l))
  upper_0_indices = jnp.triu_indices(num_m, 0, num_l)
  zero_vec = jnp.zeros((num_l,))

  a0 = -0.5 / (m_mat - 1.0)
  a0_masked = coeff_zeros.at[upper_0_indices].set(a0[upper_0_indices])
  a0_masked = a0_masked.at[1, :].set(zero_vec)

  b0 = l_mat + m_mat
  c0 = a0 * (b0 - 2.0) * (b0 - 1.0)
  c0_masked = coeff_zeros.at[upper_0_indices].set(c0[upper_0_indices])
  c0_masked = c0_masked.at[1, :].set(zero_vec)

  # p_l^{m-1}.
  p_mm1_l = (jnp.einsum('ij,ijk->ijk', a0_masked, p_m_lm1) +
             jnp.einsum('ij,ijk->ijk', c0_masked, p_mm2_lm1))

  d0 = -0.5 / (m_mat + 1.0)
  d0_masked = coeff_zeros.at[upper_0_indices].set(d0[upper_0_indices])
  e0 = d0 * b0 * (b0 + 1.0)
  e0_masked = coeff_zeros.at[upper_0_indices].set(e0[upper_0_indices])

  # p_l^{m+1}.
  p_mp1_l = (jnp.einsum('ij,ijk->ijk', d0_masked, p_mp2_lm1) +
             jnp.einsum('ij,ijk->ijk', e0_masked, p_m_lm1))

  f0 = b0 * (l_mat - m_mat + 1.0) / 2.0
  f0_masked = coeff_zeros.at[upper_0_indices].set(f0[upper_0_indices])
  p_derivative = jnp.einsum('ij,ijk->ijk', f0_masked,
                            p_mm1_l) - 0.5 * p_mp1_l

  # Special treatment of the singularity at m = 1.
  if num_m > 1:
    l_vec = jnp.arange(num_l)
    g0 = jnp.einsum('i,ij->ij', (l_vec + 1) * l_vec, p[0, :, :])
    if num_l > 2:
      g0 = g0 - p[2, :, :]
    p_derivative_m0 = jnp.einsum('j,ij->ij', 0.5 / jnp.sqrt(1 - x * x), g0)
    p_derivative = p_derivative.at[1, :, :].set(p_derivative_m0)
    p_derivative = p_derivative.at[1, 0, :].set(jnp.zeros((num_x,)))

  return p_derivative
    fun_rep = jit(nloglik)
    gra_rep = jit(grad(fun_rep))
    true_fun_rep = fun_rep(MLE_rep)
    true_gra_rep = gra_rep(MLE_rep)
    true_grnorm_rep = man.norm(MLE_rep, true_gra_rep)
    # print('Reparametrized function on MLE: ', true_fun_rep)
    # print('Gradient norm of reparametrized function on MLE: ',
    #       true_grnorm_rep)

    init_rep = jnp.identity(p + 1)
    init_cho = jnp.ones_like(MLE_chol)

    print('Start conjugate gradient optimization...')
    result_rcg = optim_rcg.solve(fun_rep, gra_rep, x=init_rep)
    result_rcg.pprint()

    print('Start riemannian descent optimization...')
    result_rsd = optim_rsd.solve(fun_rep, gra_rep, x=init_rep)
    result_rsd.pprint()

    print('Start cholesky optimization...')
    start = time()
    result_cho = minimize(fun_chol, init_cho, method='cg', jac=gra_chol,
                          tol=tol)
    cov = jnp.zeros(shape=(p + 1, p + 1)).at[
        jnp.triu_indices(p + 1)].set(result_cho.x).T
    time_cho = time() - start
    print("{}\n\t{} iterations in {:.2f} s".format(
        result_cho['message'], result_cho['nit'], time_cho))
def urt(x, box):
    distance_matrix = distance(x, box)
    i, j = np.triu_indices(len(distance_matrix), k=1)
    return distance_matrix[i, j]
chol_gra = [jnp.linalg.norm(gra_chol(init_chol))]

def store(X):
    chol_fun.append(func_chol(X))
    chol_gra.append(jnp.linalg.norm(gra_chol(X)))

res = minimize(func_chol, init_chol, method='newton-cg', jac=gra_chol,
               callback=store, options={'disp': True})
chol, mu_chol = res.x[:-p], res.x[-p:]
sig_chol = jnp.zeros(shape=(p, p)).at[jnp.triu_indices(p)].set(chol).T
sig_chol = jnp.einsum('ij,kj', sig_chol, sig_chol)
chol_fun = jnp.array(chol_fun)
chol_gra = jnp.array(chol_gra)
toc = time()

########################################
########################################
## Print results:
man_2 = SPD(p)
print("\n=================\n\tResults:\n")
print("Full Riemannian:")
print("\tStarting loglik {:.5e}".format(func(startmu, startsig)))
print("\tTime spent {:.2f} s".format(res_riem.time))
        res = res.at[i, run, 4].set(man.dist(result.x, MLE_rep))
        res = res.at[i, run, 5].set(result.grnorm)
        res = res.at[i, run, 6].set(i)

    if chol:
        start = time()
        result = minimize(fun_chol, init_chol, method='cg', jac=gra_chol,
                          options={'maxiter': maxiter_chol}, tol=tol)
        # print("{} {} iterations in {:.2f} s".format(
        #     res['message'], res['nit'], time() - start))
        cov = jnp.zeros(shape=(p + 1, p + 1)).at[
            jnp.triu_indices(p + 1)].set(result.x).T
        res_cho = res_cho.at[run, 0].set(p)
        res_cho = res_cho.at[run, 1].set(time() - start)
        res_cho = res_cho.at[run, 2].set(result['nit'])
        res_cho = res_cho.at[run, 3].set(
            (result['fun'] - true_fun_chol) / true_fun_chol)
        res_cho = res_cho.at[run, 4].set(man.dist(cov @ cov.T, MLE_rep))
        res_cho = res_cho.at[run, 5].set(jnp.linalg.norm(result.jac))
        res_cho = res_cho.at[run, 6].set(3)

columns = [
    'Matrix dimension', 'Time', 'Iterations', 'Function difference',
    'Matrix distance', 'Gradient norm', 'Algorithm'