Example #1
  def test_avg_pool(self):
    X1 = np.ones((4, 2, 3, 2))
    X2 = np.ones((3, 2, 3, 2))

    _, apply_fn, kernel_fn = stax.AvgPool((2, 2), (1, 1), 'SAME',
                                          normalize_edges=False)
    _, apply_fn_norm, kernel_fn_norm = stax.AvgPool((2, 2), (1, 1), 'SAME',
                                                    normalize_edges=True)
    _, apply_fn_stax = stax.ostax.AvgPool((2, 2), (1, 1), 'SAME')

    out1 = apply_fn((), X1)
    out2 = apply_fn((), X2)

    out1_norm = apply_fn_norm((), X1)
    out2_norm = apply_fn_norm((), X2)

    out1_stax = apply_fn_stax((), X1)
    out2_stax = apply_fn_stax((), X2)

    self.assertAllClose((out1_stax, out2_stax), (out1_norm, out2_norm), True)

    out_unnorm = np.array([[1., 1., 0.5], [0.5, 0.5, 0.25]]).reshape(
        (1, 2, 3, 1))
    out1_unnormalized = np.broadcast_to(out_unnorm, X1.shape)
    out2_unnormalized = np.broadcast_to(out_unnorm, X2.shape)

    self.assertAllClose((out1_unnormalized, out2_unnormalized), (out1, out2),
                        True)

    ker = kernel_fn(X1, X2)
    ker_norm = kernel_fn_norm(X1, X2)

    self.assertAllClose(np.ones_like(ker_norm.nngp), ker_norm.nngp, True)
    self.assertAllClose(np.ones_like(ker_norm.var1), ker_norm.var1, True)
    self.assertAllClose(np.ones_like(ker_norm.var2), ker_norm.var2, True)

    self.assertEqual(ker_norm.nngp.shape, ker.nngp.shape)
    self.assertEqual(ker_norm.var1.shape, ker.var1.shape)
    self.assertEqual(ker_norm.var2.shape, ker.var2.shape)

    ker_unnorm = np.outer(out_unnorm, out_unnorm).reshape((2, 3, 2, 3))
    ker_unnorm = np.transpose(ker_unnorm, axes=(0, 2, 1, 3))
    nngp = np.broadcast_to(
        ker_unnorm.reshape((1, 1) + ker_unnorm.shape), ker.nngp.shape)
    var1 = np.broadcast_to(np.expand_dims(ker_unnorm, 0), ker.var1.shape)
    var2 = np.broadcast_to(np.expand_dims(ker_unnorm, 0), ker.var2.shape)
    self.assertAllClose((nngp, var1, var2), (ker.nngp, ker.var1, ker.var2),
                        True)
Example #2
    def update(self, x, y_true, params, averager=None):
        # Run forward pass.
        z1, h1, z2, h2 = self.forward(x, params, return_activations=True)

        # Compute error for the final layer (= gradient of cost w.r.t. layer input).
        e2 = h2 - y_true  # gradient through cross entropy loss

        # Compute gradients of cost w.r.t. parameters.
        grad_b2 = e2
        grad_W2 = np.outer(h1, e2)

        # Update parameters.
        self.b2 -= params['lr'] * grad_b2
        self.W2 -= params['lr'] * grad_W2

        return h2
Example #3
 def update_fn(sample, state):
     """
     :param sample: A new sample.
     :param state: Current state of the scheme.
     :return: new state for the scheme.
     """
     mean, m2, n = state
     n = n + 1
     delta_pre = sample - mean
     mean = mean + delta_pre / n
     delta_post = sample - mean
     if diagonal:
         m2 = m2 + delta_pre * delta_post
     else:
         m2 = m2 + np.outer(delta_post, delta_pre)
     return mean, m2, n
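A minimal usage sketch (assuming the closed-over `diagonal` flag is False, so `m2` accumulates a full covariance): feed samples one at a time, then normalize `m2` by n - 1 to recover the sample covariance.

import numpy as onp

samples = onp.random.randn(100, 3)
state = (onp.zeros(3), onp.zeros((3, 3)), 0)
for s in samples:
    state = update_fn(s, state)
mean, m2, n = state
cov = m2 / (n - 1)  # matches onp.cov(samples.T) up to floating point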
Example #4
def qmult(key, b):
    """
    QMULT  Pre-multiply by random orthogonal matrix.
       QMULT(A) is Q*A where Q is a random real orthogonal matrix from
       the Haar distribution, of dimension the number of rows in A.
       Special case: if A is a scalar then QMULT(A) is the same as
                     QMULT(EYE(A)).
       Called by RANDSVD.
       Reference:
       G.W. Stewart, The efficient generation of random
       orthogonal matrices with an application to condition estimators,
       SIAM J. Numer. Anal., 17 (1980), 403-409.
    """
    try:
        n = b.shape[0]
        a = b.copy()
    except AttributeError:
        n = b
        a = np.eye(n)

    d = np.zeros(n)
    for k in range(n - 2, -1, -1):
        # Generate random Householder transformation.
        key, subkey = random.split(key)
        x = random.normal(subkey, (n - k, ))
        s = np.linalg.norm(x)

        # Modification to make sign(0) == 1
        sgn = np.sign(x[0]) + float(x[0] == 0)
        s = sgn * s
        d = index_update(d, k, -sgn)
        x = index_update(x, 0, x[0] + s)
        beta = s * x[0]

        # Apply the transformation to a
        y = np.dot(x, a[k:n, :])
        a = index_update(a, index[k:n, :], a[k:n, :] - np.outer(x, (y / beta)))

    # Tidy up signs.
    for i in range(n - 1):
        a = index_update(a, index[i, :], d[i] * a[i, :])

    # Now randomly change the sign (Gaussian dist)
    a = index_update(a, index[n - 1, :],
                     a[n - 1, :] * np.sign(random.normal(key, ())))
    return a
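A quick property check (a hypothetical driver, assuming the deprecated `jax.ops.index_update`/`index` API used above is available): the result should be orthogonal, and per Stewart (1980) it is Haar-distributed.

from jax import random
import jax.numpy as np

key = random.PRNGKey(0)
Q = qmult(key, 5)  # scalar argument: same as qmult(key, np.eye(5))
assert np.allclose(Q.T @ Q, np.eye(5), atol=1e-5)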
Example #5
    def testOrderOneDegreeOne(self):
        """Tests the spherical harmonics of order one and degree one."""
        num_theta = 7
        num_phi = 8
        theta = jnp.linspace(0, math.pi, num_theta)
        phi = jnp.linspace(0, 2.0 * math.pi, num_phi)

        expected = -1.0 / 2.0 * jnp.sqrt(3.0 / (2.0 * math.pi)) * jnp.outer(
            jnp.sin(theta), jnp.exp(1j * phi))
        sph_harm = spherical_harmonics.SphericalHarmonics(l_max=1,
                                                          theta=theta,
                                                          phi=phi)
        actual = sph_harm.harmonics_nonnegative_order()[1, 1, :, :]
        np.testing.assert_allclose(jnp.abs(actual),
                                   jnp.abs(expected),
                                   rtol=1e-8,
                                   atol=6e-8)
Example #6
def var(kappa: float, mu: jnp.ndarray) -> jnp.ndarray:
    """Compute the variance of the power spherical distribution.

    Args:
        kappa: Concentration parameter.
        mu: Mean direction on the sphere. The dimensionality of the sphere is
            determined from this parameter.

    Returns:
        out: The variance of the power spherical distribution.

    """
    d = mu.size
    alpha = (d - 1.) / 2. + kappa
    beta = (d - 1.) / 2.
    return (2 * alpha / ((alpha + beta)**2 * (alpha + beta + 1.)) *
            ((beta - alpha) * jnp.outer(mu, mu) + (alpha + beta) * jnp.eye(d)))
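Example usage (a sketch): for a mean direction on the unit circle, the result is a symmetric 2 x 2 covariance matrix.

import jax.numpy as jnp

mu = jnp.array([1.0, 0.0])  # mean direction on the unit circle
v = var(10.0, mu)
assert v.shape == (2, 2)
assert jnp.allclose(v, v.T)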
Example #7
def linreg_imputation_model(X, y):

    ndims = X.shape[1]
    a = numpyro.sample("a", dist.Normal(0, 0.5))

    beta = numpyro.sample("beta", dist.Normal(0, 0.5).expand([ndims]))
    sigma_y = numpyro.sample("sigma_y", dist.Exponential(1))

    # X_impute contains imputed data for each feature as a list
    # X_merged is the observed data filled with imputed values at missing points.
    X_impute = [None] * ndims
    X_merged = [None] * ndims

    for i in range(ndims):  # for every feature
        no_of_missed = int(np.isnan(X[:, i]).sum())

        if no_of_missed != 0:
            # each nan value is associated with an imputed variable with a standard normal prior.
            X_impute[i] = numpyro.sample(
                "X_impute_{}".format(i),
                dist.Normal(0, 1).expand([no_of_missed]).mask(False))

            # merging the observed data with the imputed values.
            missed_idx = np.nonzero(np.isnan(X[:, i]))[0]
            X_merged[i] = ops.index_update(X[:, i], missed_idx, X_impute[i])

        # if there are no missing values, it's just the observed data.
        else:
            X_merged[i] = X[:, i]

    merged_X = jnp.stack(X_merged).T

    # LKJ is the distribution to model correlation matrices.
    rho = numpyro.sample("rho", dist.LKJ(ndims, 2))  # correlation matrix
    sigma_x = numpyro.sample("sigma_x", dist.Exponential(1).expand([ndims]))
    covariance_x = jnp.outer(sigma_x, sigma_x) * rho  # covariance matrix
    mu_x = numpyro.sample("mu_x", dist.Normal(0, 0.5).expand([ndims]))

    numpyro.sample("X_merged",
                   dist.MultivariateNormal(mu_x, covariance_x),
                   obs=merged_X)

    mu_y = a + merged_X @ beta

    numpyro.sample("y", dist.Normal(mu_y, sigma_y), obs=y)
Example #8
def weighted_sum(mean, cov, weights):
    """
    Computes the mean and variance of a weighted sum of an MVN random variable.
    Args:
        mean (np.array): The mean of the MVN.
        cov (np.array): The covariance of the MVN.
        weights (np.array): A vector of weights to give the elements.
    Returns:
        Tuple[float, float]: The mean and variance of the weighted sum.
    """

    mean_summed_theta = np.dot(mean, weights)

    outer_x = np.outer(weights, weights)
    multiplied = cov * outer_x
    weighted_sum = np.sum(multiplied)

    return mean_summed_theta, weighted_sum
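Sanity check (a sketch): for x ~ N(mean, cov), Var(w·x) = wᵀ cov w, which is exactly what the sum over cov * outer(w, w) computes.

import numpy as np

mean = np.array([1.0, 2.0])
cov = np.array([[2.0, 0.3], [0.3, 1.0]])
w = np.array([0.5, -1.0])
m, v = weighted_sum(mean, cov, w)
assert np.isclose(v, w @ cov @ w)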
Example #9
def linreg_model(X, y):

    ndims = X.shape[1]
    a = numpyro.sample("a", dist.Normal(0, 0.5))

    beta = numpyro.sample("beta", dist.Normal(0, 0.5).expand([ndims]))
    sigma_y = numpyro.sample("sigma_y", dist.Exponential(1))

    # LKJ is the distribution to model correlation matrices.
    rho = numpyro.sample("rho", dist.LKJ(ndims, 2))  # correlation matrix
    sigma_x = numpyro.sample("sigma_x", dist.Exponential(1).expand([ndims]))
    covariance_x = jnp.outer(sigma_x, sigma_x) * rho  # covariance matrix
    mu_x = numpyro.sample("mu_x", dist.Normal(0, 0.5).expand([ndims]))

    numpyro.sample("X", dist.MultivariateNormal(mu_x, covariance_x), obs=X)

    mu_y = a + X @ beta

    numpyro.sample("y", dist.Normal(mu_y, sigma_y), obs=y)
Example #10
def householder_product(inputs: Array, q_vector: Array) -> Array:
    """
    Args:
        inputs (Array) : inputs for the householder product
        (D,)
        q_vector (Array): vector to be multiplied
        (D,)
    
    Returns:
        outputs (Array) : outputs after the householder product
    """
    # norm for q_vector
    squared_norm = jnp.sum(q_vector**2)
    # inner product
    temp = jnp.dot(inputs, q_vector)
    # outer product
    temp = jnp.outer(temp, (2.0 / squared_norm) * q_vector).squeeze()
    # update
    output = inputs - temp
    return output
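Property check (a sketch): a Householder reflection negates `q_vector` itself and leaves anything orthogonal to it unchanged.

import jax.numpy as jnp

q = jnp.array([1.0, 2.0, 2.0])
e = jnp.array([2.0, -1.0, 0.0])  # orthogonal to q
assert jnp.allclose(householder_product(q, q), -q, atol=1e-5)
assert jnp.allclose(householder_product(e, q), e, atol=1e-5)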
Example #11
        def update_fun(step, grads, state):
            """Apply a step of the optimzier."""
            del step  # Unused.
            params, grad_seq = state

            # Update gradient history.
            grad_seq = append_to_sequence(grad_seq, grads)

            # Compute normalized gram matrix.
            gram = innerprod(grad_seq, grad_seq)
            grad_norm = norms(grad_seq)
            gram /= (jnp.outer(grad_norm, grad_norm) + 1e-6)

            # Compute update terms.
            attn_weights = jnp.dot(stax.softmax(gram, axis=0), theta_gram)
            attn_term = jnp.tensordot(attn_weights, grad_seq, axes=1)
            grad_term = jnp.tensordot(theta_grad, grad_seq, axes=1)
            params -= (grad_term + attn_term)

            return (params, grad_seq)
Example #12
    def nngp_ntk_fn(nngp, q11, q22, ntk=None):
      """Simple Gauss-Hermite quadrature routine."""
      xs, ws = quad_points
      grid = np.outer(ws, ws)
      x = xs.reshape((xs.shape[0],) + (1,) * (nngp.ndim + 1))
      y = xs.reshape((1, xs.shape[0]) + (1,) * nngp.ndim)
      xy_axes = (0, 1)

      nngp = np.expand_dims(nngp, xy_axes)
      q11, q22 = np.expand_dims(q11, xy_axes), np.expand_dims(q22, xy_axes)

      def integrate(f):
        fvals = f(_sqrt(2 * q11) * x) * f(
            nngp / _sqrt(q11 / 2, 1e-30) * x + _sqrt(
                2*(q22 - nngp**2/q11)) * y)
        return np.tensordot(grid, fvals, (xy_axes, xy_axes)) / np.pi

      if ntk is not None:
        ntk *= integrate(df)
      nngp = integrate(fn)
      return nngp, ntk
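For background, a sketch of plain 1-D Gauss-Hermite quadrature, which the routine above extends to the bivariate Gaussian case: E[f(Z)] for Z ~ N(0, 1) becomes a weighted sum over the Hermite nodes.

import numpy as onp

xs, ws = onp.polynomial.hermite.hermgauss(32)
approx = onp.sum(ws * onp.tanh(onp.sqrt(2.0) * xs)) / onp.sqrt(onp.pi)
assert abs(approx) < 1e-12  # tanh is odd, so E[tanh(Z)] = 0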
Example #13
def linear_matter_power(cosmo,
                        k,
                        a=1.0,
                        transfer_fn=tklib.Eisenstein_Hu,
                        **kwargs):
    r""" Computes the linear matter power spectrum.

  Parameters
  ----------
  cosmo: Cosmology
      Cosmology parameter object

  k: array_like
      Wave number in h Mpc^{-1}

  a: array_like, optional
      Scale factor (def: 1.0)

  transfer_fn: transfer_fn(cosmo, k, **kwargs)
      Transfer function

  Returns
  -------
  pk: array_like
      Linear matter power spectrum at the specified scale
      and scale factor.

  """
    k = np.atleast_1d(k)
    a = np.atleast_1d(a)
    g = bkgrd.growth_factor(cosmo, a)
    t = transfer_fn(cosmo, k, **kwargs)

    pknorm = cosmo.sigma8**2 / sigmasqr(cosmo, 8.0, transfer_fn, **kwargs)

    if k.ndim == 1:
        pk = np.outer(primordial_matter_power(cosmo, k) * t**2, g**2)
    else:
        pk = primordial_matter_power(cosmo, k) * t**2 * g**2

    # Apply normalisation
    pk = pk * pknorm
    return pk.squeeze()
Example #14
def _calc_vars(theta: np.array):
    """
    Calculate the mean and variance of the posterior.


    mu = theta(1, 5) * mask(5, 2) = (1, 2) vector
    s_1, s_2 = theta(n, 5) * mask(5, 2) * theta(5, n) = (n,) scalar
    rho = theta(1, 5) * mask(5, 1) = (1, 1) vector

    Sigma = [[s_1 ** 2, rho * s_1 * s_2],
            [rho * s_1 * s_2, s_2 ** 2]]

    Sigma = outer([s_1, s_2], [s_1, s_2]) * elementwise_mult rho
    """
    # have to do it this way to allow vectorization
    s_vec = np.square(theta[2:4])
    rho = np.tanh(theta[5])
    rho_matrix = np.eye(2) + rho * np.eye(2)[::-1]  # off-diagonal rho

    mu = theta[:2]
    Sigma = np.outer(s_vec, s_vec) * rho_matrix
    return mu, Sigma
Example #15
    def update(state: WelfordAlgorithmState,
               value: np.DeviceArray) -> WelfordAlgorithmState:
        """Update the M2 matrix using the new value.

        Parameters
        ----------
        state:
            The current state of the Welford Algorithm
        value: jax.numpy.DeviceArray, shape (1,)
            The new sample (typically the position of the chain) used to update m2
        """
        mean, m2, sample_size = state
        sample_size = sample_size + 1

        delta = value - mean
        mean = mean + delta / sample_size
        updated_delta = value - mean
        if is_diagonal_matrix:
            new_m2 = m2 + delta * updated_delta
        else:
            new_m2 = m2 + np.outer(updated_delta, delta)

        return WelfordAlgorithmState(mean, new_m2, sample_size)
Example #16
def positional_encoding(seq_len, embed_dim, timescale=10000):
    """
    Returns positional encoding values.
    Assumes seq dimensions are (batch, seq_len, word_space).

    Output shape:
        seq_len x embed_dim
    """

    if embed_dim % 2 != 0:
        raise ValueError("Embedding dimension must be even")

    positions = jnp.arange(seq_len)
    i = jnp.arange(embed_dim // 2)
    angular_frequencies = 1 / jnp.power(timescale, 2 * i / embed_dim)

    angles = jnp.outer(positions, angular_frequencies)
    cosine = jnp.cos(angles)  # seq_len, embed_dim // 2
    sine = jnp.sin(angles)  # seq_len, embed_dim // 2

    pos_enc = jnp.concatenate([cosine, sine], axis=1)

    return pos_enc
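Usage sketch: build the table once and broadcast-add it to a batch of embeddings.

import jax.numpy as jnp

pe = positional_encoding(seq_len=50, embed_dim=128)
assert pe.shape == (50, 128)
# for embeddings of shape (batch, 50, 128):
# embeddings = embeddings + pe[None, :, :]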
Example #17
    def update(state: WelfordAlgorithmState,
               value: np.DeviceArray) -> WelfordAlgorithmState:
        """Update the M2 matrix using the new value.

        Parameters
        ----------
        state:
            The current state of the Welford Algorithm
        value: jax.numpy.DeviceArray, shape (1,)
            The new sample used to update m2
        """
        mean, m2, count = state
        count = count + 1

        delta = value - mean
        mean = mean + delta / count
        updated_delta = value - mean
        if is_diagonal_matrix:
            m2 = m2 + delta * updated_delta
        else:
            m2 = m2 + np.outer(delta, updated_delta)

        return WelfordAlgorithmState(mean, m2, count)
Example #18
def house_rightmult(A, v, beta):
    """
    Given the m x n matrix A and the length-n vector v with normalization
    beta such that P = I - beta v otimes dag(v) is the Householder matrix that
    reflects about v, compute AP.

    Parameters
    ----------
    A:  array_like, shape(M, N)
        Matrix to be multiplied by H.

    v:  array_like, shape(N).
        Householder vector.

    beta: float
        Householder normalization.

    Returns
    -------
    C = AP
    """
    C = A - jnp.outer(A @ v, beta * dag(v))
    return C
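Minimal check (a sketch; `dag` is not shown above and is assumed to be the conjugate transpose):

import jax.numpy as jnp

dag = lambda x: jnp.conj(x).T

v = jnp.array([1.0, 1.0, 0.0])
beta = 2.0 / (v @ v)  # makes P = I - beta * v * dag(v) a reflector
P = jnp.eye(3) - beta * jnp.outer(v, jnp.conj(v))
A = jnp.arange(6.0).reshape(2, 3)
assert jnp.allclose(house_rightmult(A, v, beta), A @ P, atol=1e-5)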
Example #19
def mkcovdiag_ASD(len_sc, rho, nxcirc, wvec=None, wwnrm=None):
    # Eigenvalues of the ASD covariance (as diagonalized in the Fourier domain)
    #
    # cdiag = mkcovdiag_ASD(len_sc, rho, nxcirc, wvec)
    #
    # Compute the discrete ASD (RBF kernel) eigenspectrum using frequencies in
    # [0, nxcirc]. See mkCov_ASD_factored for more info.
    #
    # INPUT (all python 1d arrays):
    #     len_sc - length scale of the ASD kernel (determines smoothness)
    #        rho - maximal prior variance ("overall scale")
    #     nxcirc - number of coefficients to consider for the circular boundary
    #       wvec - vector of frequencies for the DFT
    #      wwnrm - vector of frequencies for the DFT (normalized)
    #
    # OUTPUT:
    #     cdiag [nxcirc x 1] - vector of eigenvalues of C for the frequencies in wvec
    #
    # Note: nxcirc = nx corresponds to having a circular boundary

    # Compute diagonal of ASD covariance matrix
    if wvec is not None:
        wvecsq = np.square(wvec)
        const = np.square(2 * np.pi / nxcirc)  # constant
        ww = wvecsq * const  # effective frequency vector
    elif wwnrm is not None:
        ww = wwnrm
    else:
        raise ValueError(
            "please provide either wvec or a normalized wvec (wwnrm)")
    cdiag = np.squeeze(
        np.sqrt(2 * np.pi) * rho * len_sc *
        np.exp(-.5 * np.outer(ww, np.square(len_sc))))
    return cdiag
Example #20
 def searchphase(y, x):
     y_t = jnp.outer(y, TPS)
     x_t = jnp.tile(x[:,None], (1, testing_phases))
     snr_t = 10. * jnp.log10(getpower(y_t) / getpower(y_t - x_t))
     return TPS[jnp.argmax(snr_t)]
Example #21
def _gyration_tensor(positions):
    n = positions.shape[0]
    S = np.zeros((3, 3))
    for r in positions:
        S += np.outer(r, r)
    return S / n
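An equivalent vectorized form (a sketch): the per-particle loop of outer products is a single contraction over the particle axis, which avoids Python-level looping under jit.

import jax.numpy as np

def _gyration_tensor_vectorized(positions):
    return np.einsum('ni,nj->ij', positions, positions) / positions.shape[0]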
Example #22
 def outer(self, tensor_in_1, tensor_in_2):
     return jnp.outer(tensor_in_1, tensor_in_2)
Example #23
        def affine_transform(dist_params, scale, shift, value_transform=None):
            """ implements the "Categorical Algorithm" from https://arxiv.org/abs/1707.06887 """

            # check inputs
            chex.assert_rank([dist_params['logits'], scale, shift],
                             [2, {0, 1}, {0, 1}])
            p = jax.nn.softmax(dist_params['logits'])
            batch_size = p.shape[0]

            if isscalar(scale):
                scale = jnp.full(shape=(batch_size, ),
                                 fill_value=jnp.squeeze(scale))
            if isscalar(shift):
                shift = jnp.full(shape=(batch_size, ),
                                 fill_value=jnp.squeeze(shift))

            chex.assert_shape(p, (batch_size, self.num_bins))
            chex.assert_shape([scale, shift], (batch_size, ))

            if value_transform is None:
                f = f_inv = lambda x: x
            else:
                f, f_inv = value_transform

            # variable names correspond to those defined in: https://arxiv.org/abs/1707.06887
            z = self.__atoms
            Vmin, Vmax, Δz = z[0], z[-1], z[1] - z[0]
            Tz = f(jax.vmap(jnp.add)(jnp.outer(scale, f_inv(z)), shift))
            Tz = jnp.clip(Tz, Vmin, Vmax)  # keep values in valid range
            chex.assert_shape(Tz, (batch_size, self.num_bins))

            b = (Tz - Vmin) / Δz  # float in [0, num_bins - 1]
            l = jnp.floor(b).astype('int32')  # noqa: E741  # int in {0, 1, ..., num_bins - 1}
            u = jnp.ceil(b).astype('int32')  # int in {0, 1, ..., num_bins - 1}
            chex.assert_shape([p, b, l, u], (batch_size, self.num_bins))

            m = jnp.zeros_like(p)
            i = jnp.expand_dims(jnp.arange(batch_size), axis=1)  # batch index
            m = jax.ops.index_add(m, (i, l),
                                  p * (u - b),
                                  indices_are_sorted=True)
            m = jax.ops.index_add(m, (i, u),
                                  p * (b - l),
                                  indices_are_sorted=True)
            m = jax.ops.index_add(m, (i, l),
                                  p * (l == u),
                                  indices_are_sorted=True)
            # chex.assert_tree_all_close(jnp.sum(m, axis=1), jnp.ones(batch_size), rtol=1e-6)

            # # The above index trickery is equivalent to:
            # m_alt = onp.zeros((batch_size, self.num_bins))
            # for i in range(batch_size):
            #     for j in range(self.num_bins):
            #         if l[i, j] == u[i, j]:
            #             m_alt[i, l[i, j]] += p[i, j]  # don't split if b[i, j] is an integer
            #         else:
            #             m_alt[i, l[i, j]] += p[i, j] * (u[i, j] - b[i, j])
            #             m_alt[i, u[i, j]] += p[i, j] * (b[i, j] - l[i, j])
            # chex.assert_tree_all_close(m, m_alt, rtol=1e-6)
            return {'logits': jnp.log(jnp.maximum(m, 1e-16))}
Example #24
        optimizers = [
            optim_rcg,
            optim_rsd,
            #optim_rlbfgs
        ]
        RNG, key = random.split(RNG)
        t_cov, t_mu = orig_man.rand(key)
        RNG, key = random.split(RNG)
        data = random.multivariate_normal(key,
                                          mean=t_mu,
                                          cov=t_cov,
                                          shape=(N, ))
        s_mu = jnp.mean(data, axis=0)
        s_cov = jnp.dot((data - s_mu).T, data - s_mu) / N

        MLE_rep = jnp.append(jnp.append(s_cov + jnp.outer(s_mu, s_mu),
                                        jnp.array([s_mu]),
                                        axis=0),
                             jnp.array([jnp.append(s_mu, 1)]).T,
                             axis=1)
        if chol:
            MLE_chol = jnp.linalg.cholesky(MLE_rep)
            MLE_chol = MLE_chol.T[~(MLE_chol.T == 0.)].ravel()

        def nloglik(X):
            y = jnp.concatenate([data.T, jnp.ones(shape=(1, N))], axis=0)
            datapart = jnp.trace(jnp.linalg.solve(X, jnp.matmul(y, y.T)))
            return 0.5 * (N * jnp.linalg.slogdet(X)[1] + datapart)

        if chol:
Example #25
def compute_OBC_energy_vectorized(
    distance_matrix,
    radii,
    scales,
    charges,
    offset=0.009,
    screening=138.935484,
    surface_tension=28.3919551,
    solvent_dielectric=78.5,
    solute_dielectric=1.0,
):
    """Compute GBSA-OBC energy from a distance matrix"""
    N = len(radii)
    eye = np.eye(N, dtype=distance_matrix.dtype)
    r = distance_matrix + eye  # so I don't have divide-by-zero nonsense
    or1 = radii.reshape((N, 1)) - offset
    or2 = radii.reshape((1, N)) - offset
    sr2 = scales.reshape((1, N)) * or2

    L = np.maximum(or1, abs(r - sr2))
    U = r + sr2
    # `step` is presumably the Heaviside step function (1 where its argument is >= 0)
    I = step(r + sr2 - or1) * 0.5 * (1 / L - 1 / U + 0.25 * (r - sr2**2 / r) *
                                     (1 / (U**2) - 1 /
                                      (L**2)) + 0.5 * np.log(L / U) / r)

    I -= np.diag(np.diag(I))
    I = np.sum(I, axis=1)

    # okay, next compute born radii
    offset_radius = radii - offset
    psi = I * offset_radius
    psi_coefficient = 0.8
    psi2_coefficient = 0
    psi3_coefficient = 2.909125

    psi_term = (psi_coefficient * psi) + (psi2_coefficient *
                                          psi**2) + (psi3_coefficient * psi**3)

    B = 1 / (1 / offset_radius - np.tanh(psi_term) / radii)

    # finally, compute the three energy terms
    E = 0.0

    # single particle
    E += np.sum(surface_tension * (radii + 0.14)**2 * (radii / B)**6)
    E += np.sum(-0.5 * screening *
                (1 / solute_dielectric - 1 / solvent_dielectric) * charges**2 /
                B)

    # particle pair
    f = np.sqrt(r**2 + np.outer(B, B) * np.exp(-r**2 / (4 * np.outer(B, B))))
    charge_products = np.outer(charges, charges)

    E += np.sum(
        np.triu(-screening * (1 / solute_dielectric - 1 / solvent_dielectric) *
                charge_products / f,
                k=1))

    return E
Example #26
File: ops.py  Project: dfm/celeriac
def _matmul_impl(state: Carry, data: MatmulData) -> Tuple[Carry, Array]:
    (Fp, Vp, Yp) = state
    Vn, Pn, Yn = data
    Fn = _pdot(Pn, Fp + jnp.outer(Vp, Yp))
    return (Fn, Vn, Yn), Fn
Example #27
File: ops.py  Project: dfm/celeriac
def _solve_impl(state: Carry, data: Data) -> Tuple[Carry, Array]:
    Fp, Wp, Zp = state
    Un, Wn, Pn, Yn = data
    Fn = _pdot(Pn, Fp + jnp.outer(Wp, Zp))
    Zn = Yn - Un @ Fn
    return (Fn, Wn, Zn), Zn
Example #28
    def __init__(self, n_samp, seed, outdtype, pardtype, holomorphic):

        self.dtype = outdtype

        self.target = {
            "a": jnp.array([[[0j], [0j]]], dtype=jnp.complex128),
            "b": jnp.array(0, dtype=jnp.float64),
            "c": jnp.array(0j, dtype=jnp.complex64),
        }

        if pardtype is None:  # mixed precision as above
            pass
        else:
            self.target = jax.tree_map(
                lambda x: astype_unsafe(x, pardtype),
                self.target,
            )

        k = jax.random.PRNGKey(seed)
        k1, k2, k3, k4, k5 = jax.random.split(k, 5)

        self.samples = jax.random.normal(k1, (n_samp, 2))
        self.w = jax.random.normal(k2, (n_samp,), self.dtype).astype(
            self.dtype
        )  # TODO remove astype once it's fixed in jax
        self.params = tree_random_normal_like(k3, self.target)
        self.v = tree_random_normal_like(k4, self.target)
        self.grad = tree_random_normal_like(k5, self.target)

        if holomorphic:

            @partial(jax.vmap, in_axes=(None, 0))
            def f(params, x):
                return astype_unsafe(
                    params["a"][0][0][0] * x[0]
                    + params["b"] * x[1]
                    + params["c"] * (x[0] * x[1])
                    + jnp.sin(x[1] * params["a"][0][1][0])
                    * jnp.cos(x[0] * params["b"] + 1j)
                    * params["c"],
                    self.dtype,
                )

        else:

            @partial(jax.vmap, in_axes=(None, 0))
            def f(params, x):
                return astype_unsafe(
                    params["a"][0][0][0].conjugate() * x[0]
                    + params["b"] * x[1]
                    + params["c"] * (x[0] * x[1])
                    + jnp.sin(x[1] * params["a"][0][1][0])
                    * jnp.cos(x[0] * params["b"].conjugate() + 1j)
                    * params["c"].conjugate(),
                    self.dtype,
                )

        self.f = f

        self.params_real_flat = tree_toreal_flat(self.params)
        self.grad_real_flat = tree_toreal_flat(self.grad)
        self.v_real_flat = tree_toreal_flat(self.v)
        self.ok_real = self.grads_real(self.params_real_flat, self.samples)
        self.okmean_real = self.ok_real.mean(axis=0)
        self.dok_real = self.ok_real - self.okmean_real
        self.S_real = (
            self.dok_real.conjugate().transpose() @ self.dok_real / n_samp
        ).real
        self.scale = jnp.sqrt(self.S_real.diagonal())
        self.S_real_scaled = self.S_real / (jnp.outer(self.scale, self.scale))
Example #29
def normal_expected_stats(normal):
    mu, sigma = normal_nat_to_std(*normal)
    t1 = np.outer(mu, mu) + sigma
    t2 = mu
    return t1, t2
Example #30
def niw_std_to_nat(mu, kappa, psi, nu):
    n1 = kappa * np.outer(mu, mu) + psi
    n2 = kappa * mu
    n3 = kappa
    n4 = nu + psi.shape[0] + 2
    return n1, n2, n3, n4
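Usage sketch: map standard normal-inverse-Wishart parameters to natural parameters for a 2-dimensional Gaussian.

import jax.numpy as np

mu = np.array([0.0, 1.0])
n1, n2, n3, n4 = niw_std_to_nat(mu, kappa=2.0, psi=np.eye(2), nu=4.0)
assert n1.shape == (2, 2)
assert n4 == 4.0 + 2 + 2  # nu + d + 2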