Example #1
def squareform(distances):
    """
    IN: output from `pairwise_distances`, an array of length l = (n^2 - n) / 2 with entries d(x1, x2), d(x1, x3), ..., d(xn-1, xn).
    OUT: a symmetric n x n distance matrix with entries d(x_i, x_j)
    """
    l = distances.shape[0]
    n = getn(l)
    out = np.zeros((n, n))
    out = index_update(out, index[np.triu_indices(n, k=1)], distances)
    out = out + out.T
    return out
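The example assumes a `getn` helper that recovers n from the condensed length l = (n^2 - n) / 2; a minimal sketch of that helper, solving the quadratic for n:

import math

def getn(l):
    # Invert l = n * (n - 1) / 2 for n.
    n = int((1 + math.sqrt(1 + 8 * l)) / 2)
    assert n * (n - 1) // 2 == l, "not a valid condensed length"
    return n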
Example #2
def get_nuclear_interaction_energy(locations, nuclear_charges, interaction_fn):
    """Gets nuclear interaction energy for atomic chain.

  Args:
    locations: Float numpy array with shape (num_nuclei,),
        the locations of the nuclei.
    nuclear_charges: Float numpy array with shape (num_nuclei,),
        the charges of nuclei.
    interaction_fn: function that takes displacements and returns a
        float numpy array with the same shape as displacements.

  Returns:
    Float.

  Raises:
    ValueError: If locations.ndim or nuclear_charges.ndim is not 1.
  """
    if locations.ndim != 1:
        raise ValueError('locations.ndim is expected to be 1 but got %d' %
                         locations.ndim)
    if nuclear_charges.ndim != 1:
        raise ValueError(
            'nuclear_charges.ndim is expected to be 1 but got %d' %
            nuclear_charges.ndim)
    # Convert locations and nuclear_charges to jax.numpy array.
    locations = jnp.array(locations)
    nuclear_charges = jnp.array(nuclear_charges)
    indices_0, indices_1 = jnp.triu_indices(locations.size, k=1)
    charges_products = nuclear_charges[indices_0] * nuclear_charges[indices_1]
    return jnp.sum(charges_products *
                   interaction_fn(locations[indices_0] - locations[indices_1]))
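A hedged usage sketch; the exponential interaction function below is illustrative only, not the `interaction_fn` used by the original project:

import numpy as np
import jax.numpy as jnp

locations = np.array([-1.0, 0.0, 1.0])
charges = np.array([1.0, 2.0, 1.0])
energy = get_nuclear_interaction_energy(
    locations, charges, interaction_fn=lambda d: jnp.exp(-jnp.abs(d)))
print(float(energy))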
Example #3
 def nloglik_chol(X):
     cov = index_update(jnp.zeros(shape=(p + 1, p + 1)),
                        jnp.triu_indices(p + 1), X).T
     logdet = 2 * jnp.sum(jnp.log(jnp.diag(cov)))  # log|Sigma| for Sigma = cov @ cov.T
     y = jnp.concatenate([data.T, jnp.ones(shape=(1, N))], axis=0)
     sol = jnp.linalg.solve(cov, y)
     return 0.5 * (N * logdet + jnp.einsum('ij,ij', sol, sol))
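Here `cov` is the transpose of an upper-triangular fill, i.e. a lower-triangular Cholesky factor, so log|Sigma| = 2 * sum(log(diag(cov))); a quick standalone check of that identity:

import jax.numpy as jnp

L = jnp.array([[2.0, 0.0],
               [1.0, 3.0]])   # lower-triangular Cholesky factor
sigma = L @ L.T
assert jnp.allclose(2 * jnp.sum(jnp.log(jnp.diag(L))),
                    jnp.linalg.slogdet(sigma)[1])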
Example #4
File: sim.py Project: GenosW/NSSC2
def Epot_lj(positions, L: float, M: int):
    """Potential energy for Lennard-Jones potential in reduced units.
        In this system of units, epsilon=1 and sigma=2**(-1. / 6.). 
        
        The function accepts numpy arrays of shape (M, 3) [2D] or (M*3) [1D]."""
    if (positions.ndim != 2 or positions.shape[1] != 3
        ) and not (positions.ndim == 1 and positions.size == M * 3):
        raise ValueError(
            "positions must be an Mx3 array or a 1D array that can be reshaped to Mx3!"
        )
    if positions.ndim == 1 and positions.size == M * 3:
        positions = positions.reshape((M, 3))  # Reshape to Mx3
    #sig = 1 / np.power(2, 1 / 6)
    sig = 1.

    # Compute all squared distances between pairs
    delta = positions[:, np.newaxis, :] - positions
    delta = delta - L * np.around(delta / L, decimals=0)
    r2 = (delta * delta).sum(axis=2)  # r^2 ...squared distances

    # Take only the upper triangle (combinations of two atoms).
    indices = np.triu_indices(r2.shape[0], k=1)
    rm2 = sig * sig / r2[indices]  # (sig/r)^2
    # Compute the potential energy, recycling as many calculations as possible.
    rm6 = rm2 * rm2 * rm2  # (sig/r)^6
    rm12 = rm6 * rm6  # (sig/r)^12
    return (rm12 - 2. * rm6).sum()
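A usage sketch with a random configuration (box size and particle count are illustrative):

import numpy as np

M, box = 8, 5.0
rng = np.random.default_rng(0)
positions = rng.uniform(0.0, box, size=(M, 3))
print(Epot_lj(positions, L=box, M=M))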
Example #5
def dot_interact(concat_features):
    """Performs feature interaction operation between dense or sparse features.
  Input tensors represent dense or sparse features.
  Pre-condition: The tensors have been stacked along dimension 1.
  Args:
    concat_features: Array of features with shape [B, n_features, feature_dim].
  Returns:
    activations: Array representing interacted features.
  """
    batch_size = concat_features.shape[0]

    # Interact features, select upper or lower-triangular portion, and re-shape.
    xactions = jnp.matmul(concat_features,
                          jnp.transpose(concat_features, [0, 2, 1]))
    feature_dim = xactions.shape[-1]

    indices = jnp.array(jnp.triu_indices(feature_dim))
    num_elems = indices.shape[1]
    indices = jnp.tile(indices, [1, batch_size])
    indices0 = jnp.reshape(
        jnp.tile(jnp.reshape(jnp.arange(batch_size), [-1, 1]), [1, num_elems]),
        [1, -1])
    indices = tuple(jnp.concatenate((indices0, indices), 0))
    activations = xactions[indices]
    activations = jnp.reshape(activations, [batch_size, -1])
    return activations
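A small usage sketch: with 3 stacked features per example, the output keeps the upper triangle (including the diagonal) of each 3x3 Gram matrix, so 6 values per example:

import jax
import jax.numpy as jnp

feats = jax.random.normal(jax.random.PRNGKey(0), (2, 3, 3))  # [B, n_features, feature_dim]
print(dot_interact(feats).shape)                             # (2, 6)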
Example #6
def _gen_recurrence_mask(
        l_max: int,
        is_normalized: bool = True) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """Generates mask for recurrence relation on the remaining entries.

  The remaining entries are those apart from the diagonal and off-diagonal
  entries, which are handled separately.

  Args:
    l_max: see `gen_normalized_legendre`.
    is_normalized: True if the recurrence mask is used by normalized associated
      Legendre functions.

  Returns:
    Arrays representing the mask used by the recurrence relations.
  """

    # Computes all coefficients.
    m_mat, l_mat = jnp.mgrid[:l_max + 1, :l_max + 1]
    if is_normalized:
        c0 = l_mat * l_mat
        c1 = m_mat * m_mat
        c2 = 2.0 * l_mat
        c3 = (l_mat - 1.0) * (l_mat - 1.0)
        d0 = jnp.sqrt((4.0 * c0 - 1.0) / (c0 - c1))
        d1 = jnp.sqrt(((c2 + 1.0) * (c3 - c1)) / ((c2 - 3.0) * (c0 - c1)))
    else:
        d0 = (2.0 * l_mat - 1.0) / (l_mat - m_mat)
        d1 = (l_mat + m_mat - 1.0) / (l_mat - m_mat)

    d0_mask_indices = jnp.triu_indices(l_max + 1, 1)
    d1_mask_indices = jnp.triu_indices(l_max + 1, 2)
    d_zeros = jnp.zeros((l_max + 1, l_max + 1))
    d0_mask = d_zeros.at[d0_mask_indices].set(d0[d0_mask_indices])
    d1_mask = d_zeros.at[d1_mask_indices].set(d1[d1_mask_indices])

    # Creates a 3D mask that contains 1s on the diagonal plane and 0s elsewhere.
    # i = jnp.arange(l_max + 1)[:, None, None]
    # j = jnp.arange(l_max + 1)[None, :, None]
    # k = jnp.arange(l_max + 1)[None, None, :]
    i, j, k = jnp.ogrid[:l_max + 1, :l_max + 1, :l_max + 1]
    mask = 1.0 * (i + j - k == 0)

    d0_mask_3d = jnp.einsum('jk,ijk->ijk', d0_mask, mask)
    d1_mask_3d = jnp.einsum('jk,ijk->ijk', d1_mask, mask)

    return (d0_mask_3d, d1_mask_3d)
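A usage sketch (the l_max value is illustrative); both returned masks are dense (l_max + 1)^3 arrays:

d0_mask_3d, d1_mask_3d = _gen_recurrence_mask(l_max=2)
print(d0_mask_3d.shape, d1_mask_3d.shape)   # (3, 3, 3) (3, 3, 3)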
Example #7
def triu_matrix_from_v(x, ndim):
    assert x.shape[-1] == (ndim * (ndim + 1)) // 2
    matrix = jnp.zeros(x.shape[:-1] + (ndim, ndim))
    idx = jnp.triu_indices(ndim)
    index_update = lambda x, idx, y: x.at[idx].set(y)
    for _ in range(x.ndim - 1):
        index_update = jax.vmap(index_update, in_axes=(0, None, 0))
    return index_update(matrix, idx, x)
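A usage sketch: packing a batch of two length-6 vectors into 3x3 upper-triangular matrices:

import jax
import jax.numpy as jnp

v = jnp.arange(12.0).reshape(2, 6)
print(triu_matrix_from_v(v, ndim=3).shape)   # (2, 3, 3)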
Example #8
def jax_vech(X):
    '''
    Half-vectorization operator; returns the length n(n+1)/2 vector of the
    stacked columns of the unique entries of a symmetric n x n matrix.
    '''
    rix, cix = jnp.triu_indices(len(X))
    res = jnp.take(X.T, rix * len(X) + cix)
    return res
Example #9
File: CS.py Project: netket/netket
def minimum_distance(x, sdim):

    n_particles = x.shape[0] // sdim
    x = x.reshape(-1, sdim)

    distances = (-x[jnp.newaxis, :, :] +
                 x[:, jnp.newaxis, :])[jnp.triu_indices(n_particles, 1)]
    return jnp.linalg.norm(distances, axis=1)
Example #10
def chol_sample(key, d):
    idx_u = jnp.triu_indices(d)
    idx_d = jnp.diag_indices(d)

    L = random.normal(key, (d, d), dtype=jnp.float64)
    L = ops.index_update(L, idx_u, 0.0)
    L = ops.index_update(L, idx_d, random.normal(key, (d, ))**2)

    return L
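`jax.ops.index_update` was removed from JAX in 2022; a sketch of the same sampler written against the current `.at[...]` API (the same substitution applies to the other `index_update` snippets on this page):

import jax.numpy as jnp
from jax import random

def chol_sample_at(key, d):
    idx_u = jnp.triu_indices(d)
    idx_d = jnp.diag_indices(d)
    L = random.normal(key, (d, d))
    L = L.at[idx_u].set(0.0)                            # zero upper triangle and diagonal
    L = L.at[idx_d].set(random.normal(key, (d,)) ** 2)  # positive diagonal (key reuse mirrors the original)
    return L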
Example #11
def jax_invech(v):
    '''
    Inverse half vectorization operator
    '''
    rows = int(jnp.round(.5 * (-1 + jnp.sqrt(1 + 8 * len(v)))))
    res = jnp.zeros((rows, rows))
    res = jax.ops.index_update(res, jnp.triu_indices(rows), v)
    res = res + res.T - jnp.diag(jnp.diag(res))
    return res
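A round-trip sketch (under the legacy `jax.ops` API these two helpers assume): `jax_invech(jax_vech(X))` recovers a symmetric X:

import jax.numpy as jnp

X = jnp.array([[2.0, 1.0, 0.5],
               [1.0, 3.0, 0.2],
               [0.5, 0.2, 1.0]])
assert jnp.allclose(jax_invech(jax_vech(X)), X)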
Example #12
    def distance(self, x, sdim, L):
        n_particles = x.shape[0] // sdim
        x = x.reshape(-1, sdim)

        dis = -x[jnp.newaxis, :, :] + x[:, jnp.newaxis, :]

        dis = dis[jnp.triu_indices(n_particles, 1)]
        dis = L[jnp.newaxis, :] / 2.0 * jnp.sin(jnp.pi * dis / L[jnp.newaxis, :])
        return dis
Example #13
def get_all_pairs_indices(n: int) -> Tuple[Array, Array]:
    """all indices i, j such that i < j < n"""
    n_interactions = n * (n - 1) / 2

    inds_i, inds_j = np.triu_indices(n, k=1)

    assert len(inds_i) == n_interactions

    return inds_i, inds_j
Example #14
def minimum_distance(x, sdim):
    """Computes distances between particles using minimum image convention"""
    n_particles = x.shape[0] // sdim
    x = x.reshape(-1, sdim)

    distances = (-x[jnp.newaxis, :, :] +
                 x[:, jnp.newaxis, :])[jnp.triu_indices(n_particles, 1)]
    distances = jnp.remainder(distances + L / 2.0, L) - L / 2.0

    return jnp.linalg.norm(distances, axis=1)
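Note that `L` (the box length) is a free variable here; a usage sketch with an explicit box, showing the minimum image convention at work:

import jax.numpy as jnp

L = 4.0                                        # box length (illustrative)
x = jnp.array([0.5, 0.5, 0.5, 3.9, 0.5, 0.5])  # two particles in 3D, flattened
print(minimum_distance(x, sdim=3))             # -> [0.6] (minimum image), not 3.4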
Example #15
def ll_chol(pars, y):
    p = y.shape[-1]
    X, theta = pars[:-p], pars[-p:]
    sigma = index_update(jnp.zeros(shape=(p, p)), jnp.triu_indices(p), X).T
    sigma = jnp.matmul(sigma, sigma.T)
    sc = jnp.sqrt(jnp.diag(sigma))
    al = jnp.einsum('i,i->i', 1 / sc, theta)
    capital_phi = jnp.sum(norm.logcdf(jnp.matmul(al, y.T)))
    small_phi = jnp.sum(mvn.logpdf(y, mean=jnp.zeros(p), cov=sigma))
    # N * log(2) is the skew-normal normalizing constant.
    return -(y.shape[0] * jnp.log(2.0) + small_phi + capital_phi)
Example #16
  def call(self,
           inputs: Mapping[str, jnp.ndarray],
           rng: jnp.ndarray=None,
           sample: Optional[bool]=False,
           **kwargs
  ) -> Mapping[str, jnp.ndarray]:
    outputs = {}

    dim, dtype = inputs["x"].shape[-1], inputs["x"].dtype

    L     = hk.get_parameter("L", shape=(dim, dim), dtype=dtype, init=hk.initializers.RandomNormal(0.01))
    U     = hk.get_parameter("U", shape=(dim, dim), dtype=dtype, init=hk.initializers.RandomNormal(0.01))
    log_d = hk.get_parameter("log_d", shape=(dim,), dtype=dtype, init=jnp.zeros)
    lower_mask = jnp.ones((dim, dim), dtype=bool)
    lower_mask = jax.ops.index_update(lower_mask, jnp.triu_indices(dim), False)

    if self.safe_diag:
      d = util.proximal_relu(log_d) + 1e-5
      log_d = jnp.log(d)

    def b_init(shape, dtype):
      x = inputs["x"]
      if x.ndim == 1:
        return jnp.zeros(shape, dtype=dtype)

      # Initialize to the batch mean
      z = jnp.dot(x, (U*lower_mask.T).T) + x
      z *= jnp.exp(log_d)
      z = jnp.dot(z, (L*lower_mask).T) + z
      b = -jnp.mean(z, axis=0)
      return b

    b = hk.get_parameter("b", shape=(dim,), dtype=dtype, init=b_init)

    # It's way faster to allocate a full matrix for L and U and then mask than it
    # is to allocate only the lower/upper parts and then reshape.
    if sample == False:
      x = inputs["x"]
      z = jnp.dot(x, (U*lower_mask.T).T) + x
      z *= jnp.exp(log_d)
      z = jnp.dot(z, (L*lower_mask).T) + z
      outputs["x"] = z + b
    else:
      z = inputs["x"]

      @self.auto_batch
      def invert(z):
        x = L_solve(L, z - b)
        x = x*jnp.exp(-log_d)
        return U_solve(U, x)

      outputs["x"] = invert(z)

    outputs["log_det"] = jnp.sum(log_d, axis=-1)*jnp.ones(self.batch_shape)
    return outputs
Example #17
 def f_bond_length(x):
     #     reshape (n_atoms,3)
     x = jnp.reshape(x, (self.n_atoms, 3))
     #     compute all difference
     z = x[:, None] - x[None, :]
     #     select upper diagonal (LEXIC ORDER)
     i0 = jnp.triu_indices(self.n_atoms, 1)
     diff = z[i0]
     #     compute the bond length
     r = jnp.linalg.norm(diff, axis=1)
     return r
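`f_bond_length` closes over `self.n_atoms`; a standalone sketch of the same computation for three atoms:

import jax.numpy as jnp

def bond_lengths(x, n_atoms):
    x = jnp.reshape(x, (n_atoms, 3))
    diff = (x[:, None] - x[None, :])[jnp.triu_indices(n_atoms, 1)]
    return jnp.linalg.norm(diff, axis=1)

print(bond_lengths(jnp.array([0., 0., 0., 1., 0., 0., 0., 1., 0.]), 3))
# -> [1.0, 1.0, 1.4142135], pairwise bond lengths in lexicographic order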
Example #18
def registration_individuals(x,y,aar_names,max_iter=10000,aars=None):

  if aars is None:
      aars = range(0,len(aar_names))
    
  aar_indices = [y == aar for aar in aars]
  uti_indices = [np.triu_indices(sum(y == aar),k=1) for aar in aars]
    
  def cost_function(x,y):
    def foo(x,uti): 
      dr = (x[:,uti[0]]-x[:,uti[1]])
      return np.sqrt(np.sum(dr*dr,axis=0)).sum()
    return sum([foo(x[:,aar_indices[aar]],uti_indices[aar]) for aar in range(0,len(aars))])

  def transform(param,x):
    thetas = param[0:len(x)]
    delta_ps = np.reshape(param[len(x):],(2,len(x)))
    return np.hstack([np.dot(np.array([[np.cos(theta),-np.sin(theta)],[np.sin(theta),np.cos(theta)]]),x_s)+np.expand_dims(delta_p,1) for theta,delta_p,x_s in zip(thetas,delta_ps.T,x)])

  def func(param,x,y):
    value = cost_function(transform(param,x),y)
    return value

  loss = lambda param: func(param,x,y)
    
  opt_init, opt_update, get_params = optimizers.adagrad(step_size=1,momentum=0.9)

  @jit
  def step(i, opt_state):
    params = get_params(opt_state)
    g = grad(loss)(params)
    return opt_update(i, g, opt_state)

  net_params = numpy.hstack((numpy.random.uniform(-numpy.pi,numpy.pi,len(x)),numpy.zeros(2*len(x))))
  previous_value = loss(net_params)
  logging.info('Iteration 0: loss = %f'%(previous_value))
  opt_state = opt_init(net_params)
  for i in range(max_iter):
    opt_state = step(i, opt_state)
    if i > 0 and i % 10 == 0:
      net_params = get_params(opt_state)
      current_value = loss(net_params)
      logging.info('Iteration %d: loss = %f'%(i+1,current_value))

      if numpy.isclose(previous_value/current_value,1):
          logging.info('Converged after %d iterations'%(i+1))
          net_params = get_params(opt_state)
          return transform(net_params,x)

      previous_value = current_value

  logging.warning('Not converged after %d iterations'%(i+1))
  net_params = get_params(opt_state)
  return transform(net_params,x)
Example #19
def jax_l2_pdist(X):
  """Computes the pairwise distances between points in X.
  
  Args:
    X: A 2d numpy array, the points.
  Returns:
    dm: The pairwise distances between points in X, as a flattened 
      upper-triangular matrix.
  """
  n = X.shape[0]
  diffs = (X[:, None] - X[None, :])[np.triu_indices(n=n, k=1)]
  return np.linalg.norm(diffs, axis=1)
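The row-major upper-triangle ordering matches SciPy's condensed distance form, so the result can be sanity-checked against scipy.spatial.distance.pdist:

import numpy as onp
from scipy.spatial.distance import pdist

X = onp.random.default_rng(0).normal(size=(5, 3))
assert onp.allclose(jax_l2_pdist(X), pdist(X))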
Example #20
def unpack_triu(x, n, hermi=0):
    R = np.zeros([n, n])
    idx = np.triu_indices(n)
    R = jax.ops.index_update(R, idx, x)
    if hermi == 0:
        return R
    elif hermi == 1:
        R = R + R.conj().T
        R = jax.ops.index_mul(R, np.diag_indices(n), 0.5)
        return R
    elif hermi == 2:
        return R - R.conj().T
    else:
        raise KeyError
Example #21
 def unflattened(self, x: Array, dimensions: int) -> Array:
     k = x.shape[-1]
     sqrt_discriminant = sqrt(1 + 8 * k)
     i_sqrt_discriminant = int(sqrt_discriminant)
     if i_sqrt_discriminant != sqrt_discriminant:
         raise ValueError(f"{k} {sqrt_discriminant}")
     if i_sqrt_discriminant % 2 != 1:
         raise ValueError
     dimensions = (i_sqrt_discriminant - 1) // 2
     index = (..., *jnp.triu_indices(dimensions))
     empty = jnp.empty(x.shape[:-1] + (dimensions, dimensions),
                       dtype=x.dtype)
     lower_diagonal = empty.at[index].set(x).T
     if self.hermitian:
         lower_diagonal = lower_diagonal.conjugate()
     return lower_diagonal.at[index].set(x)
Example #22
def run(manifold, p, k):
    k, key = random.split(k)
    tslant = random.normal(key, shape=(p,))
    
    k, key = random.split(k)
    tcov = random.normal(key, shape=(p, p))
    tcov = tcov @ tcov.T

    tmean = jnp.zeros(shape=(p,))

    sn = SkewNormal(loc=tmean, cov=tcov, sl=tslant)

    k, key = random.split(k)
    data = sn.sample(key, shape=(N,))
    
    # s_mu = jnp.mean(data, axis=0)
    # s_cov = jnp.dot((data - s_mu).T, data - s_mu) / N
    # MLE = jnp.append(jnp.append(s_cov + jnp.outer(s_mu, s_mu),
    #                             jnp.array([s_mu]), axis=0),
    #                  jnp.array([jnp.append(s_mu, 1)]).T, axis=1)
    # mle_chol = jnp.linalg.cholesky(MLE)
    # mle_chol = mle_chol.T[jnp.triu_indices_from(mle_chol)]
    
    # data = jnp.concatenate([data.T, jnp.ones(shape=(1, N))], axis=0).T

    fun = jit(lambda x, y: ll(x, y, data))
    # gra = jit(grad(fun))
    init = (jnp.identity(p), jnp.ones(shape=(p,)))
    # print(fun(init[0], init[1]))

    # ll_mle = fun(MLE)
    
    res_cg = optimization('rcg', manifold, fun=fun, init=init)
    res_bfgs = optimization('rlbfgs', manifold, fun=fun, init=init)

    fun = jit(lambda x, y: ll_chol(x, y, data))
    init = (jnp.identity(p)[jnp.triu_indices(p)], jnp.ones(shape=(p,)))
    # gra = jit(grad(fun))

    # ll_mle_chol = fun(mle_chol)

    res_cho = optimization('chol', fun=fun, init=init)
    
    return p, *res_cg, *res_bfgs, *res_cho
Example #23
                #print('Maxiterations reached')
                break
            if jnp.isclose(f0, old_f0, rtol=tol):
                #print('Function not changing')
                break
            if (gr_sig_norm <= tol) and (gr_the_norm <= tol):
                #print('Reached mingradnorm')
                break

            old_f0 = f0
        toc = time()
        res.append([p, k, toc - tic, f0])

        tic = time()
        init_chol = jnp.append(
            jnp.identity(p)[jnp.triu_indices(p)], jnp.ones(shape=(p, )))
        fun_chol = jit(lambda x: ll_chol(x, data))
        gra_chol = jit(grad(fun_chol))
        res_chol = minimize(fun_chol,
                            init_chol,
                            method='cg',
                            jac=gra_chol,
                            tol=tol)
        toc = time()
        res[-1] = res[-1] + [res_chol['nit'], toc - tic, res_chol['fun']]

df = pd.DataFrame(data=res,
                  columns=[
                      'p', 'riem_iter', 'riem_time', 'riem_fun', 'chol_iter',
                      'chol_time', 'chol_fun'
                  ])
Example #24
        #                          maxiter=maxiter, mingradnorm=tol,
        #                          verbosity=0, logverbosity=logs)

        optimizers = [
            optim_rcg,
            optim_rsd,
            #optim_rlbfgs
        ]
        RNG, key = random.split(RNG)
        data, t_cov, t_mu = generate_data(key, p)

        MLE_rep = t_cov, t_mu

        if chol:
            MLE_chol = jnp.linalg.cholesky(t_cov)
            MLE_chol = jnp.append(MLE_chol.T[jnp.triu_indices(p)], t_mu)

        def nloglik(X):
            sigma = X[0]
            theta = X[1]
            return ll(sigma, theta, data)

        if chol:

            def nloglik_chol(X):
                return ll_chol(X, data)

            fun_chol = jit(nloglik_chol)
            gra_chol = jit(grad(fun_chol))

            true_fun_chol = fun_chol(MLE_chol)
Example #25
def ll_chol(pars, y):
    p = y.shape[-1]
    X, theta = pars[:-p], pars[-p:]
    sigma = index_update(jnp.zeros(shape=(p, p)), jnp.triu_indices(p), X).T
    sigma = jnp.matmul(sigma, sigma.T)
    return ll(sigma, theta, y)
Example #26
def _gen_derivatives(p: jnp.ndarray, x: jnp.ndarray,
                     is_normalized: bool) -> jnp.ndarray:
    """Generates derivatives of associated Legendre functions of the first kind.

  Args:
    p: The 3D array containing the values of associated Legendre functions; the
      dimensions are in the sequence of order (m), degree (l), and evaluation
      points.
    x: A vector of type `float32` or `float64` containing the sampled points.
    is_normalized: True if the associated Legendre functions are normalized.
  Returns:
    The 3D array representing the derivatives of associated Legendre functions
    of the first kind.
  """

    num_m, num_l, num_x = p.shape

    # p_{l-1}^m.
    p_m_lm1 = jnp.pad(p, ((0, 0), (1, 0), (0, 0)))[:, :num_l, :]

    # p_{l-1}^{m+2}.
    p_mp2_lm1 = jnp.pad(p_m_lm1, ((0, 2), (0, 0), (0, 0)))[2:num_m + 2, :, :]

    # p_{l-1}^{m-2}.
    p_mm2_lm1 = jnp.pad(p_m_lm1, ((2, 0), (0, 0), (0, 0)))[:num_m, :, :]

    # Derivative computation requires negative orders.
    if is_normalized:
        raise NotImplementedError(
            'Negative orders for normalization is not implemented yet.')
    else:
        if num_l > 1:
            l_vec = jnp.arange(1, num_l - 1)
            p_p1 = p[1, 1:num_l - 1, :]
            coeff = -1.0 / ((l_vec + 1) * l_vec)
            update_p_p1 = jnp.einsum('i,ij->ij', coeff, p_p1)
            p_mm2_lm1 = p_mm2_lm1.at[1, 2:num_l, :].set(update_p_p1)

        if num_l > 2:
            l_vec = jnp.arange(2, num_l - 1)
            p_p2 = p[2, 2:num_l - 1, :]
            coeff = 1.0 / ((l_vec + 2) * (l_vec + 1) * l_vec)
            update_p_p2 = jnp.einsum('i,ij->ij', coeff, p_p2)
            p_mm2_lm1 = p_mm2_lm1.at[0, 3:num_l, :].set(update_p_p2)

    m_mat, l_mat = jnp.mgrid[:num_m, :num_l]

    coeff_zeros = jnp.zeros((num_m, num_l))
    upper_0_indices = jnp.triu_indices(num_m, 0, num_l)
    zero_vec = jnp.zeros((num_l, ))

    a0 = -0.5 / (m_mat - 1.0)
    a0_masked = coeff_zeros.at[upper_0_indices].set(a0[upper_0_indices])
    a0_masked = a0_masked.at[1, :].set(zero_vec)

    b0 = l_mat + m_mat
    c0 = a0 * (b0 - 2.0) * (b0 - 1.0)
    c0_masked = coeff_zeros.at[upper_0_indices].set(c0[upper_0_indices])
    c0_masked = c0_masked.at[1, :].set(zero_vec)

    # p_l^{m-1}.
    p_mm1_l = (jnp.einsum('ij,ijk->ijk', a0_masked, p_m_lm1) +
               jnp.einsum('ij,ijk->ijk', c0_masked, p_mm2_lm1))

    d0 = -0.5 / (m_mat + 1.0)
    d0_masked = coeff_zeros.at[upper_0_indices].set(d0[upper_0_indices])
    e0 = d0 * b0 * (b0 + 1.0)
    e0_masked = coeff_zeros.at[upper_0_indices].set(e0[upper_0_indices])

    # p_l^{m+1}.
    p_mp1_l = (jnp.einsum('ij,ijk->ijk', d0_masked, p_mp2_lm1) +
               jnp.einsum('ij,ijk->ijk', e0_masked, p_m_lm1))

    f0 = b0 * (l_mat - m_mat + 1.0) / 2.0
    f0_masked = coeff_zeros.at[upper_0_indices].set(f0[upper_0_indices])
    p_derivative = jnp.einsum('ij,ijk->ijk', f0_masked,
                              p_mm1_l) - 0.5 * p_mp1_l

    # Special treatment of the singularity at m = 1.
    if num_m > 1:
        l_vec = jnp.arange(num_l)
        g0 = jnp.einsum('i,ij->ij', (l_vec + 1) * l_vec, p[0, :, :])
        if num_l > 2:
            g0 = g0 - p[2, :, :]
        p_derivative_m0 = jnp.einsum('j,ij->ij', 0.5 / jnp.sqrt(1 - x * x), g0)
        p_derivative = p_derivative.at[1, :, :].set(p_derivative_m0)
        p_derivative = p_derivative.at[1, 0, :].set(jnp.zeros((num_x, )))

    return p_derivative
Example #27
fun_rep = jit(nloglik)
gra_rep = jit(grad(fun_rep))

true_fun_rep = fun_rep(MLE_rep)
true_gra_rep = gra_rep(MLE_rep)
true_grnorm_rep = man.norm(MLE_rep, true_gra_rep)
# print('Reparametrized function on MLE: ', true_fun_rep)
# print('Gradient norm of reparametrized function on MLE: ', true_grnorm_rep)

init_rep = jnp.identity(p + 1)
init_cho = jnp.ones_like(MLE_chol)

print('Start conjugate gradient optimization...')
result_rcg = optim_rcg.solve(fun_rep, gra_rep, x=init_rep)
result_rcg.pprint()

print('Start riemannian descent optimization...')
result_rsd = optim_rsd.solve(fun_rep, gra_rep, x=init_rep)
result_rsd.pprint()

print('Start cholesky optimization...')
start = time()
result_cho = minimize(fun_chol, init_cho, method='cg', jac=gra_chol, tol=tol)
cov = index_update(
    jnp.zeros(shape=(p+1, p+1)),
    jnp.triu_indices(p+1),
    result_cho.x).T
time_cho = time() - start
print("{}\n\t{} iterations in {:.2f} s".format(result_cho['message'],
      result_cho['nit'], time_cho))
Example #28
 def urt(x, box):
     distance_matrix = distance(x, box)
     i, j = np.triu_indices(len(distance_matrix), k=1)
     return distance_matrix[i, j]
Example #29
chol_fun = [func_chol(init_chol)]
chol_gra = [jnp.linalg.norm(gra_chol(init_chol))]


def store(X):
    chol_fun.append(func_chol(X))
    chol_gra.append(jnp.linalg.norm(gra_chol(X)))


res = minimize(func_chol,
               init_chol,
               method='newton-cg',
               jac=gra_chol,
               callback=store,
               options={'disp': True})
chol, mu_chol = res.x[:-p], res.x[-p:]
sig_chol = index_update(jnp.zeros(shape=(p, p)), jnp.triu_indices(p), chol).T
sig_chol = jnp.einsum('ij,kj', sig_chol, sig_chol)

chol_fun = jnp.array(chol_fun)
chol_gra = jnp.array(chol_gra)

toc = time()
########################################

########################################
## Print results:
man_2 = SPD(p)
print("\n=================\n\tResults:\n")
print("Full Riemannian:")
print("\tStarting loglik {:.5e}".format(func(startmu, startsig)))
print("\tTime spent {:.2f} s".format(res_riem.time))
Example #30
            res = index_update(res, index[i, run, 4],
                               man.dist(result.x, MLE_rep))
            res = index_update(res, index[i, run, 5], result.grnorm)
            res = index_update(res, index[i, run, 6], i)

        if chol:
            start = time()
            result = minimize(fun_chol,
                              init_chol,
                              method='cg',
                              jac=gra_chol,
                              options={'maxiter': maxiter_chol},
                              tol=tol)
            # print("{} {} iterations in {:.2f} s".format(res['message'], res['nit'], time() - start))
            cov = index_update(jnp.zeros(shape=(p + 1, p + 1)),
                               jnp.triu_indices(p + 1), result.x).T
            res_cho = index_update(res_cho, index[run, 0], p)
            res_cho = index_update(res_cho, index[run, 1], time() - start)
            res_cho = index_update(res_cho, index[run, 2], result['nit'])
            res_cho = index_update(res_cho, index[run, 3],
                                   (result['fun'] - true_fun_chol) /
                                   true_fun_chol)
            res_cho = index_update(res_cho, index[run, 4],
                                   man.dist(cov @ cov.T, MLE_rep))
            res_cho = index_update(res_cho, index[run, 5],
                                   jnp.linalg.norm(result.jac))
            res_cho = index_update(res_cho, index[run, 6], 3)

    columns = [
        'Matrix dimension', 'Time', 'Iterations', 'Function difference',
        'Matrix distance', 'Gradient norm', 'Algorithm'
    ]