Exemplo n.º 1
  def test_avg_pool(self):
    # Compares three AvgPool variants (two built with `stax.AvgPool`, one with
    # `stax.ostax.AvgPool`) on all-ones inputs, checking layer outputs and
    # kernels against hand-built expectations.
    # NOTE(review): four calls in this listing are cut off mid-argument (the
    # first two `stax.AvgPool(` constructions and the `assertAllClose(` calls
    # in the middle and at the very end); their trailing arguments have to be
    # restored from the upstream source before this test can run.
    X1 = np.ones((4, 2, 3, 2))
    X2 = np.ones((3, 2, 3, 2))

    _, apply_fn, kernel_fn = stax.AvgPool((2, 2), (1, 1), 'SAME',
    _, apply_fn_norm, kernel_fn_norm = stax.AvgPool((2, 2), (1, 1), 'SAME',
    _, apply_fn_stax = stax.ostax.AvgPool((2, 2), (1, 1), 'SAME')

    # Apply every variant to both inputs; the layers take an empty params tuple.
    out1 = apply_fn((), X1)
    out2 = apply_fn((), X2)

    out1_norm = apply_fn_norm((), X1)
    out2_norm = apply_fn_norm((), X2)

    out1_stax = apply_fn_stax((), X1)
    out2_stax = apply_fn_stax((), X2)

    # The `_norm` variant is expected to agree with the `stax.ostax` layer.
    self.assertAllClose((out1_stax, out2_stax), (out1_norm, out2_norm), True)

    # Expected output of the non-`_norm` variant: a fixed 2x3 pattern,
    # broadcast over the first and last axes of the inputs.
    out_unnorm = np.array([[1., 1., 0.5], [0.5, 0.5, 0.25]]).reshape(
        (1, 2, 3, 1))
    out1_unnormalized = np.broadcast_to(out_unnorm, X1.shape)
    out2_unnormalized = np.broadcast_to(out_unnorm, X2.shape)

    self.assertAllClose((out1_unnormalized, out2_unnormalized), (out1, out2),

    ker = kernel_fn(X1, X2)
    ker_norm = kernel_fn_norm(X1, X2)

    # For the `_norm` variant every kernel entry is expected to equal one.
    self.assertAllClose(np.ones_like(ker_norm.nngp), ker_norm.nngp, True)
    self.assertAllClose(np.ones_like(ker_norm.var1), ker_norm.var1, True)
    self.assertAllClose(np.ones_like(ker_norm.var2), ker_norm.var2, True)

    # Both variants must produce kernels of identical shapes.
    self.assertEqual(ker_norm.nngp.shape, ker.nngp.shape)
    self.assertEqual(ker_norm.var1.shape, ker.var1.shape)
    self.assertEqual(ker_norm.var2.shape, ker.var2.shape)

    # Expected kernel of the non-`_norm` variant: outer product of the output
    # pattern with itself, with the two middle axes swapped (presumably to
    # match the kernel's axis layout -- TODO confirm), then broadcast to the
    # shapes of the computed kernel entries.
    ker_unnorm = np.outer(out_unnorm, out_unnorm).reshape((2, 3, 2, 3))
    ker_unnorm = np.transpose(ker_unnorm, axes=(0, 2, 1, 3))
    nngp = np.broadcast_to(
        ker_unnorm.reshape((1, 1) + ker_unnorm.shape), ker.nngp.shape)
    var1 = np.broadcast_to(np.expand_dims(ker_unnorm, 0), ker.var1.shape)
    var2 = np.broadcast_to(np.expand_dims(ker_unnorm, 0), ker.var2.shape)
    self.assertAllClose((nngp, var1, var2), (ker.nngp, ker.var1, ker.var2),
Exemplo n.º 2
    def update(self, x, y_true, params, averager=None):
        """Take one gradient step on the output-layer parameters.

        Runs the forward pass, forms the output-layer error ``h2 - y_true``
        and applies a step of size ``params['lr']`` to ``self.b2`` and
        ``self.W2`` using in-place subtraction.

        Args:
            x: Input handed to ``self.forward``.
            y_true: Target subtracted from the final activation.
            params: Mapping providing the learning rate under ``'lr'``; it is
                also forwarded to ``self.forward``.
            averager: Accepted but not used by this method.

        Returns:
            The final activation computed by the forward pass (before the
            parameter update).
        """
        # Only the hidden and final activations are needed below.
        _, hidden, _, output = self.forward(x, params, return_activations=True)

        # Error at the final layer, i.e. the gradient of the cross-entropy
        # cost with respect to that layer's input.
        output_error = output - y_true

        # Gradients for the output-layer weights; the bias gradient is the
        # error itself.
        weight_grad = np.outer(hidden, output_error)

        self.b2 -= params['lr'] * output_error
        self.W2 -= params['lr'] * weight_grad

        return output
Exemplo n.º 3
 def update_fn(sample, state):
     """Fold one new sample into the running mean / second-moment state.

     This is a Welford-style online update: the mean is advanced first and
     the second-moment accumulator uses the deviations from the mean taken
     before and after that step.

     :param sample: A new sample.
     :param state: Current state of the scheme, unpacked as ``(mean, m2, n)``.
     :return: new state for the scheme.
     """
     mean, m2, n = state
     n = n + 1
     delta_pre = sample - mean
     mean = mean + delta_pre / n
     delta_post = sample - mean
     if diagonal:
         # Diagonal case: accumulate the element-wise product only.
         m2 = m2 + delta_pre * delta_post
     else:
         # Dense case: accumulate the full outer product.
         m2 = m2 + np.outer(delta_post, delta_pre)
     return mean, m2, n
Exemplo n.º 4
def qmult(key, b):
    """Pre-multiply by a random orthogonal matrix.

    Computes ``Q @ b`` where ``Q`` is a random real orthogonal matrix from
    the Haar distribution, of dimension the number of rows of ``b``.

    Special case: if ``b`` has no ``shape`` / ``copy`` attributes (e.g. it is
    a plain integer), it is taken as the dimension ``n`` and the identity
    ``np.eye(n)`` is transformed instead, so the result is ``Q`` itself.

    Reference:
        G.W. Stewart, The efficient generation of random
        orthogonal matrices with an application to condition estimators,
        SIAM J. Numer. Anal., 17 (1980), 403-409.

    Args:
        key: PRNG key; it is split once per Householder step and the
            remaining key drives the final random sign.
        b: Matrix to pre-multiply, or an integer dimension.

    Returns:
        The transformed matrix, with the same shape as the matrix that was
        transformed (``b`` itself, or the ``n x n`` identity).
    """
    try:
        n = b.shape[0]
        a = b.copy()
    except AttributeError:
        # `b` is not array-like: treat it as a dimension.
        n = b
        a = np.eye(n)

    d = np.zeros(n)
    for k in range(n - 2, -1, -1):
        # Generate random Householder transformation.
        key, subkey = random.split(key)
        x = random.normal(subkey, (n - k, ))
        s = np.linalg.norm(x)

        # Modification to make sign(0) == 1
        sgn = np.sign(x[0]) + float(x[0] == 0)
        s = sgn * s
        d = index_update(d, k, -sgn)
        x = index_update(x, 0, x[0] + s)
        beta = s * x[0]

        # Apply the transformation to a
        y = np.dot(x, a[k:n, :])
        a = index_update(a, index[k:n, :], a[k:n, :] - np.outer(x, (y / beta)))

    # Tidy up signs.
    for i in range(n - 1):
        a = index_update(a, index[i, :], d[i] * a[i, :])

    # Now randomly change the sign (Gaussian dist)
    a = index_update(a, index[n - 1, :],
                     a[n - 1, :] * np.sign(random.normal(key, ())))
    return a
    def testOrderOneDegreeOne(self):
        """Tests the spherical harmonics of order one and degree one."""
        # Sample grid: 7 values of theta on [0, pi] and 8 values of phi on
        # [0, 2*pi].
        num_theta = 7
        num_phi = 8
        theta = jnp.linspace(0, math.pi, num_theta)
        phi = jnp.linspace(0, 2.0 * math.pi, num_phi)

        # Reference values in closed form, -1/2 * sqrt(3 / (2*pi)) *
        # sin(theta) * exp(i*phi), laid out on the (theta, phi) grid via an
        # outer product.
        expected = -1.0 / 2.0 * jnp.sqrt(3.0 / (2.0 * math.pi)) * jnp.outer(
            jnp.sin(theta), jnp.exp(1j * phi))
        # NOTE(review): the constructor call below is cut off in this listing
        # (no closing parenthesis), and no assertion comparing `actual` with
        # `expected` is visible; both need to be restored from upstream.
        sph_harm = spherical_harmonics.SphericalHarmonics(l_max=1,
        actual = sph_harm.harmonics_nonnegative_order()[1, 1, :, :]
Exemplo n.º 6
def var(kappa: float, mu: jnp.ndarray) -> jnp.ndarray:
    """Compute the variance of the power spherical distribution.

    Args:
        kappa: Concentration parameter.
        mu: Mean direction on the sphere. The dimensionality of the sphere is
            determined from this parameter (``d = mu.size``).

    Returns:
        out: The variance of the power spherical distribution, as a
            ``(d, d)`` matrix.
    """
    d = mu.size
    alpha = (d - 1.) / 2. + kappa
    beta = (d - 1.) / 2.
    return (2 * alpha / ((alpha + beta)**2 * (alpha + beta + 1.)) *
            ((beta - alpha) * jnp.outer(mu, mu) + (alpha + beta) * jnp.eye(d)))
Exemplo n.º 7
def linreg_imputation_model(X, y):
    # Linear regression of `y` on `X` in which NaN entries of `X` are replaced
    # by sampled values before the regression mean is formed.
    # NOTE(review): this listing has lost several lines (flagged individually
    # below), so the function does not parse as shown.

    ndims = X.shape[1]
    a = numpyro.sample("a", dist.Normal(0, 0.5))

    beta = numpyro.sample("beta", dist.Normal(0, 0.5).expand([ndims]))
    sigma_y = numpyro.sample("sigma_y", dist.Exponential(1))

    # X_impute contains imputed data for each feature as a list
    # X_merged is the observed data filled with imputed values at missing points.
    X_impute = [None] * ndims
    X_merged = [None] * ndims

    for i in range(ndims):  # for every feature
        no_of_missed = int(np.isnan(X[:, i]).sum())

        if no_of_missed != 0:
            # each nan value is associated with a imputed variable of std normal prior.
            # NOTE(review): the sample-site name argument appears to be
            # missing from the call below.
            X_impute[i] = numpyro.sample(
                dist.Normal(0, 1).expand([no_of_missed]).mask(False))

            # merging the observed data with the imputed values.
            missed_idx = np.nonzero(np.isnan(X[:, i]))[0]
            X_merged[i] = ops.index_update(X[:, i], missed_idx, X_impute[i])

        # if there are no missing values, its just the observed data.
        # NOTE(review): an `else:` line appears to be missing here; as listed,
        # the assignment below sits inside the `if` body and would overwrite
        # the merged column.
            X_merged[i] = X[:, i]

    # Stack the per-feature columns and transpose so features are the last axis.
    merged_X = jnp.stack(X_merged).T

    # LKJ is the distribution to model correlation matrices.
    rho = numpyro.sample("rho", dist.LKJ(ndims, 2))  # correlation matrix
    sigma_x = numpyro.sample("sigma_x", dist.Exponential(1).expand([ndims]))
    covariance_x = jnp.outer(sigma_x, sigma_x) * rho  # covariance matrix
    mu_x = numpyro.sample("mu_x", dist.Normal(0, 0.5).expand([ndims]))

    # NOTE(review): the next line is a fragment of a call whose opening and
    # closing parts are missing from this listing.
                   dist.MultivariateNormal(mu_x, covariance_x),

    mu_y = a + merged_X @ beta

    numpyro.sample("y", dist.Normal(mu_y, sigma_y), obs=y)
Exemplo n.º 8
def weighted_sum(mean, cov, weights):
    """Compute the mean and variance of a weighted sum of the MVN r.v.

    Args:
        mean (np.array): The mean of the MVN.
        cov (np.array): The covariance of the MVN.
        weights (np.array): A vector of weights to give the elements.

    Returns:
        Tuple[float, float]: The mean and variance of the weighted sum.
    """
    mean_summed_theta = np.dot(mean, weights)

    # Var(w . x) = sum_ij w_i * w_j * cov_ij: multiply the covariance
    # element-wise by the outer product of the weights and add everything up.
    weight_products = np.outer(weights, weights)
    variance = np.sum(cov * weight_products)

    return mean_summed_theta, variance
Exemplo n.º 9
def linreg_model(X, y):
    """Linear regression of ``y`` on ``X`` that also models ``X`` itself.

    Sample sites, in order: ``a``, ``beta``, ``sigma_y`` (regression
    parameters), ``rho``, ``sigma_x``, ``mu_x`` (parameters of a multivariate
    normal over the rows of ``X``), then the observed sites ``X`` and ``y``.

    Args:
        X: Feature matrix; the number of features is read from ``X.shape[1]``.
        y: Observed targets.
    """
    n_features = X.shape[1]

    # Regression intercept, coefficients and observation noise.
    intercept = numpyro.sample("a", dist.Normal(0, 0.5))
    coefs = numpyro.sample("beta", dist.Normal(0, 0.5).expand([n_features]))
    noise_scale = numpyro.sample("sigma_y", dist.Exponential(1))

    # Feature covariance = correlation matrix (LKJ prior) scaled by the outer
    # product of the per-feature scales.
    corr = numpyro.sample("rho", dist.LKJ(n_features, 2))
    scales = numpyro.sample(
        "sigma_x", dist.Exponential(1).expand([n_features]))
    feature_cov = jnp.outer(scales, scales) * corr
    feature_mean = numpyro.sample(
        "mu_x", dist.Normal(0, 0.5).expand([n_features]))

    feature_dist = dist.MultivariateNormal(feature_mean, feature_cov)
    numpyro.sample("X", feature_dist, obs=X)

    prediction = intercept + X @ coefs
    numpyro.sample("y", dist.Normal(prediction, noise_scale), obs=y)
Exemplo n.º 10
def householder_product(inputs: Array, q_vector: Array) -> Array:
    """Apply the Householder reflection defined by ``q_vector`` to ``inputs``.

    Computes ``inputs - 2 * (inputs . q) * q / ||q||^2``.

    Args:
        inputs (Array) : inputs for the householder product
        q_vector (Array): vector to be multiplied

    Returns:
        outputs (Array) : outputs after the householder product
    """
    # norm for q_vector
    squared_norm = jnp.sum(q_vector**2)
    # inner product
    temp = jnp.dot(inputs, q_vector)
    # outer product; squeeze drops the length-1 axis that `outer` creates
    # when `inputs` is a single vector
    temp = jnp.outer(temp, (2.0 / squared_norm) * q_vector).squeeze()
    # update
    output = inputs - temp
    return output
        def update_fun(step, grads, state):
            """Apply a step of the optimizer.

            ``state`` is unpacked as ``(params, grad_seq)``; the returned
            tuple has the same layout with the gradient history extended by
            ``grads`` and the parameters moved by the combined update.
            """
            del step  # Unused.
            params, history = state

            # Extend the gradient history with the newest gradient.
            history = append_to_sequence(history, grads)

            # Gram matrix of the history, divided by the outer product of the
            # gradient norms (with 1e-6 added to the denominator).
            gram = innerprod(history, history)
            history_norms = norms(history)
            gram /= (jnp.outer(history_norms, history_norms) + 1e-6)

            # Two update terms, both contractions against the history.
            mixing = jnp.dot(stax.softmax(gram, axis=0), theta_gram)
            attention_update = jnp.tensordot(mixing, history, axes=1)
            gradient_update = jnp.tensordot(theta_grad, history, axes=1)
            params -= (gradient_update + attention_update)

            return (params, history)
Exemplo n.º 12
    def nngp_ntk_fn(nngp, q11, q22, ntk=None):
      """Simple Gauss-Hermite quadrature routine.

      Integrates `fn` (and, when `ntk` is given, `df`) on the two-dimensional
      grid formed from the points and weights in `quad_points`. Returns the
      integrated `nngp` together with `ntk` (scaled by the `df` integral, or
      `None` if it was `None`).
      """
      nodes, weights = quad_points
      n_nodes = nodes.shape[0]
      quad_axes = (0, 1)

      # Quadrature nodes placed on two new leading axes; the shapes are
      # derived from `nngp.ndim` before `nngp` itself is expanded below.
      x = nodes.reshape((n_nodes,) + (1,) * (nngp.ndim + 1))
      y = nodes.reshape((1, n_nodes) + (1,) * nngp.ndim)
      weight_grid = np.outer(weights, weights)

      nngp = np.expand_dims(nngp, quad_axes)
      q11 = np.expand_dims(q11, quad_axes)
      q22 = np.expand_dims(q22, quad_axes)

      def integrate(f):
        first = f(_sqrt(2 * q11) * x)
        second = f(
            nngp / _sqrt(q11 / 2, 1e-30) * x + _sqrt(
                2*(q22 - nngp**2/q11)) * y)
        weighted = np.tensordot(weight_grid, first * second,
                                (quad_axes, quad_axes))
        return weighted / np.pi

      if ntk is not None:
        ntk *= integrate(df)
      nngp = integrate(fn)
      return nngp, ntk
Exemplo n.º 13
def linear_matter_power(cosmo,
    r""" Computes the linear matter power spectrum.

    Parameters
    ----------
    k: array_like
        Wave number in h Mpc^{-1}

    a: array_like, optional
        Scale factor (def: 1.0)

    transfer_fn: transfer_fn(cosmo, k, **kwargs)
        Transfer function

    Returns
    -------
    pk: array_like
        Linear matter power spectrum at the specified scale
        and scale factor.
    """
    # NOTE(review): the parameter list above is cut off in this listing; the
    # body reads `k`, `a`, `transfer_fn` and `kwargs` in addition to `cosmo`.
    k = np.atleast_1d(k)
    a = np.atleast_1d(a)
    g = bkgrd.growth_factor(cosmo, a)
    t = transfer_fn(cosmo, k, **kwargs)

    # Normalisation constant: sigma8 squared divided by `sigmasqr` evaluated
    # at 8.0 with the same transfer function.
    pknorm = cosmo.sigma8**2 / sigmasqr(cosmo, 8.0, transfer_fn, **kwargs)

    # NOTE(review): an `else:` line appears to be missing between the two
    # assignments below; as listed, the second one always overwrites the
    # outer-product result.
    if k.ndim == 1:
        pk = np.outer(primordial_matter_power(cosmo, k) * t**2, g**2)
        pk = primordial_matter_power(cosmo, k) * t**2 * g**2

    # Apply normalisation
    pk = pk * pknorm
    return pk.squeeze()
Exemplo n.º 14
def _calc_vars(theta: np.array):
    Calculate the mean and variance of the posterior.

    mu = theta(1, 5) * mask(5, 2) = (1, 2) vector
    s_1, s_2 = theta(n, 5) * mask(5, 2) * theta(5, n) = (n,) scalar
    rho = theta(1, 5) * mask(5, 1) = (1, 1) vector

    Sigma = [[s_1 ** 2, rho * s_1 * s_2],
            [rho * s_1 * s_2, s_2 ** 2]]

    Sigma = outer([s_1, s_2], [s_1, s_2]) * elementwise_mult rho
    # have to do it this way to allow vectorization
    s_vec = np.square(theta[2:4])
    rho = np.tanh(theta[5])
    rho_matrix = np.eye(2) + rho * np.eye(2)[::-1]  # off-diagonal rho

    mu = theta[:2]
    Sigma = np.outer(s_vec, s_vec) * rho_matrix
    return mu, Sigma
Exemplo n.º 15
    def update(state: WelfordAlgorithmState,
               value: np.DeviceArray) -> WelfordAlgorithmState:
        """Update the M2 matrix using the new value.

        Parameters
        ----------
        state:
            The current state of the Welford Algorithm
        value: jax.numpy.DeviceArray
            The new sample (typically position of the chain) used to update m2

        Returns
        -------
        A new ``WelfordAlgorithmState`` built from the updated mean, the
        updated M2 accumulator and the incremented sample size.
        """
        mean, m2, sample_size = state
        sample_size = sample_size + 1

        delta = value - mean
        mean = mean + delta / sample_size
        updated_delta = value - mean
        if is_diagonal_matrix:
            # Diagonal case: accumulate the element-wise product only.
            new_m2 = m2 + delta * updated_delta
        else:
            # Dense case: accumulate the full outer product.
            new_m2 = m2 + np.outer(updated_delta, delta)

        return WelfordAlgorithmState(mean, new_m2, sample_size)
def positional_encoding(seq_len, embed_dim, timescale=10000):
    """Return sinusoidal positional encoding values.

    Assumes seq dimensions are (batch, seq_len, word_space).

    The first ``embed_dim // 2`` columns of the result hold the cosines and
    the remaining columns the sines of ``position * frequency``, where the
    frequencies are ``1 / timescale ** (2 * i / embed_dim)``.

    Args:
        seq_len: Number of positions to encode.
        embed_dim: Width of the encoding; must be even.
        timescale: Base of the geometric frequency progression.

    Returns:
        Array of shape ``(seq_len, embed_dim)``.

    Raises:
        ValueError: If ``embed_dim`` is odd.
    """
    if embed_dim % 2 != 0:
        raise ValueError("Embedding dimension must be even")

    positions = jnp.arange(seq_len)
    i = jnp.arange(embed_dim // 2)
    angular_frequencies = 1 / jnp.power(timescale, 2 * i / embed_dim)

    angles = jnp.outer(positions, angular_frequencies)
    cosine = jnp.cos(angles)  # seq_len, embed_dim // 2
    sine = jnp.sin(angles)  # seq_len, embed_dim // 2

    pos_enc = jnp.concatenate([cosine, sine], axis=1)

    return pos_enc
Exemplo n.º 17
    def update(state: WelfordAlgorithmState,
               value: np.DeviceArray) -> WelfordAlgorithmState:
        """Update the M2 matrix using the new value.

        Parameters
        ----------
        state:
            The current state of the Welford Algorithm
        value: jax.numpy.DeviceArray
            The new sample used to update m2

        Returns
        -------
        A new ``WelfordAlgorithmState`` built from the updated mean, the
        updated M2 accumulator and the incremented count.
        """
        mean, m2, count = state
        count = count + 1

        delta = value - mean
        mean = mean + delta / count
        updated_delta = value - mean
        if is_diagonal_matrix:
            # Diagonal case: accumulate the element-wise product only.
            m2 = m2 + delta * updated_delta
        else:
            # Dense case: accumulate the full outer product.
            m2 = m2 + np.outer(delta, updated_delta)

        return WelfordAlgorithmState(mean, m2, count)
Exemplo n.º 18
def house_rightmult(A, v, beta):
    """Right-multiply ``A`` by a Householder matrix.

    Given the m x n matrix A and the length-n vector v with normalization
    beta such that P = I - beta v otimes dag(v) is the Householder matrix that
    reflects about v, compute AP.

    Parameters
    ----------
    A:  array_like, shape(M, N)
        Matrix to be multiplied by P.

    v:  array_like, shape(N).
        Householder vector.

    beta: float
        Householder normalization.

    Returns
    -------
    C:  array_like, shape(M, N)
        C = AP
    """
    # A P = A - (A v) otimes (beta * dag(v)); P itself is never formed.
    C = A - jnp.outer(A @ v, beta * dag(v))
    return C
Exemplo n.º 19
def mkcovdiag_ASD(len_sc, rho, nxcirc, wvec=None, wwnrm=None):
    """Eigenvalues of the ASD covariance (as diagonalized in Fourier domain).

    Compute discrete ASD (RBF kernel) eigenspectrum using frequencies in
    [0, nxcirc]. See mkCov_ASD_factored for more info.

    Exactly one source of frequencies is used: ``wvec`` takes precedence and
    is squared and scaled by ``(2 * pi / nxcirc) ** 2``; otherwise ``wwnrm``
    is used as given.

    Args:
        len_sc: Length scale of the ASD kernel (determines smoothness).
        rho: Maximal prior variance ("overall scale").
        nxcirc: Number of coefficients to consider for circular boundary.
            Note: nxcirc = nx corresponds to having a circular boundary.
        wvec: Vector of frequencies for the DFT.
        wwnrm: Vector of frequencies for the DFT (normalized); used only
            when ``wvec`` is None.

    Returns:
        cdiag: Vector of eigenvalues of C for the given frequencies, with
            length-1 axes squeezed out.

    Raises:
        ValueError: If neither ``wvec`` nor ``wwnrm`` is provided.
    """
    # Compute diagonal of ASD covariance matrix
    if wvec is not None:
        wvecsq = np.square(wvec)
        const = np.square(2 * np.pi / nxcirc)  # constant
        ww = wvecsq * const  # effective frequency vector
    elif wwnrm is not None:
        ww = wwnrm
    else:
        raise ValueError(
            "please provide either wvec or a normalized wvec into this function"
        )

    cdiag = np.squeeze(
        np.sqrt(2 * np.pi) * rho * len_sc *
        np.exp(-.5 * np.outer(ww, np.square(len_sc))))
    return cdiag
Exemplo n.º 20
 def searchphase(y, x):
     """Return the entry of ``TPS`` whose scaled copy of ``y`` scores best against ``x``.

     Each entry of ``TPS`` multiplies ``y`` (outer product); the score is
     ``10 * log10(power(candidate) / power(candidate - x))`` and the entry at
     the arg-max of the scores is returned.
     """
     candidates = jnp.outer(y, TPS)
     reference = jnp.tile(x[:,None], (1, testing_phases))
     candidate_power = getpower(candidates)
     residual_power = getpower(candidates - reference)
     snr_db = 10. * jnp.log10(candidate_power / residual_power)
     best = jnp.argmax(snr_db)
     return TPS[best]
Exemplo n.º 21
def _gyration_tensor(positions):
    n = positions.shape[0]
    S = np.zeros((3, 3))
    for r in positions:
        S += np.outer(r, r)
    return S / n
Exemplo n.º 22
 def outer(self, tensor_in_1, tensor_in_2):
     """Return the outer product of the two inputs, computed with ``jnp.outer``."""
     product = jnp.outer(tensor_in_1, tensor_in_2)
     return product
Exemplo n.º 23
        def affine_transform(dist_params, scale, shift, value_transform=None):
            """ implements the "Categorical Algorithm" from https://arxiv.org/abs/1707.06887 """
            # NOTE(review): this listing has lost lines: the two `jnp.full(`
            # calls and the three `index_add(` calls are cut off before their
            # final arguments, and an `else:` appears to be missing before the
            # `f, f_inv = value_transform` line.

            # check inputs
            chex.assert_rank([dist_params['logits'], scale, shift],
                             [2, {0, 1}, {0, 1}])
            p = jax.nn.softmax(dist_params['logits'])
            batch_size = p.shape[0]

            # Scalar scale/shift are expanded to one value per batch entry.
            if isscalar(scale):
                scale = jnp.full(shape=(batch_size, ),
            if isscalar(shift):
                shift = jnp.full(shape=(batch_size, ),

            chex.assert_shape(p, (batch_size, self.num_bins))
            chex.assert_shape([scale, shift], (batch_size, ))

            # Without a value transform, both directions are the identity.
            if value_transform is None:
                f = f_inv = lambda x: x
                f, f_inv = value_transform

            # variable names correspond to those defined in: https://arxiv.org/abs/1707.06887
            z = self.__atoms
            Vmin, Vmax, Δz = z[0], z[-1], z[1] - z[0]
            # Transformed atoms: f(scale * f_inv(z) + shift), one row per
            # batch entry.
            Tz = f(jax.vmap(jnp.add)(jnp.outer(scale, f_inv(z)), shift))
            Tz = jnp.clip(Tz, Vmin, Vmax)  # keep values in valid range
            chex.assert_shape(Tz, (batch_size, self.num_bins))

            # Fractional bin position of each transformed atom and its
            # lower/upper neighbouring bin indices.
            b = (Tz - Vmin) / Δz  # float in [0, num_bins - 1]
            l = jnp.floor(b).astype(
                'int32')  # noqa: E741   # int in {0, 1, ..., num_bins - 1}
            u = jnp.ceil(b).astype('int32')  # int in {0, 1, ..., num_bins - 1}
            chex.assert_shape([p, b, l, u], (batch_size, self.num_bins))

            # Redistribute probability mass onto the neighbouring bins; the
            # third update handles the case l == u (b is an integer), where
            # the first two weights are both zero.
            m = jnp.zeros_like(p)
            i = jnp.expand_dims(jnp.arange(batch_size), axis=1)  # batch index
            m = jax.ops.index_add(m, (i, l),
                                  p * (u - b),
            m = jax.ops.index_add(m, (i, u),
                                  p * (b - l),
            m = jax.ops.index_add(m, (i, l),
                                  p * (l == u),
            # chex.assert_tree_all_close(jnp.sum(m, axis=1), jnp.ones(batch_size), rtol=1e-6)

            # # The above index trickery is equivalent to:
            # m_alt = onp.zeros((batch_size, self.num_bins))
            # for i in range(batch_size):
            #     for j in range(self.num_bins):
            #         if l[i, j] == u[i, j]:
            #             m_alt[i, l[i, j]] += p[i, j]  # don't split if b[i, j] is an integer
            #         else:
            #             m_alt[i, l[i, j]] += p[i, j] * (u[i, j] - b[i, j])
            #             m_alt[i, u[i, j]] += p[i, j] * (b[i, j] - l[i, j])
            # chex.assert_tree_all_close(m, m_alt, rtol=1e-6)
            # Back to logits; the floor of 1e-16 keeps the log finite for
            # empty bins.
            return {'logits': jnp.log(jnp.maximum(m, 1e-16))}
Exemplo n.º 24
        optimizers = [
        RNG, key = random.split(RNG)
        t_cov, t_mu = orig_man.rand(key)
        RNG, key = random.split(RNG)
        data = random.multivariate_normal(key,
                                          shape=(N, ))
        s_mu = jnp.mean(data, axis=0)
        s_cov = jnp.dot((data - s_mu).T, data - s_mu) / N

        MLE_rep = jnp.append(jnp.append(s_cov + jnp.outer(s_mu, s_mu),
                             jnp.array([jnp.append(s_mu, 1)]).T,
        if chol:
            MLE_chol = jnp.linalg.cholesky(MLE_rep)
            MLE_chol = MLE_chol.T[~(MLE_chol.T == 0.)].ravel()

        def nloglik(X):
            """Objective ``0.5 * (N * log|det X| + tr(X^{-1} y y^T))``.

            ``y`` is ``data.T`` with a row of ones appended; ``data`` and
            ``N`` come from the enclosing scope.
            """
            ones_row = jnp.ones(shape=(1, N))
            augmented = jnp.concatenate([data.T, ones_row], axis=0)
            scatter = jnp.matmul(augmented, augmented.T)
            trace_term = jnp.trace(jnp.linalg.solve(X, scatter))
            log_abs_det = jnp.linalg.slogdet(X)[1]
            return 0.5 * (N * log_abs_det + trace_term)

        if chol:
def compute_OBC_energy_vectorized(
    """Compute GBSA-OBC energy from a distance matrix"""
    # NOTE(review): the parameter list is missing from this listing. The body
    # reads distance_matrix, radii, scales, charges, offset, screening,
    # surface_tension, solute_dielectric and solvent_dielectric, and calls a
    # `step` function that is not defined here.
    N = len(radii)
    eye = np.eye(N, dtype=distance_matrix.dtype)
    r = distance_matrix + eye  # so I don't have divide-by-zero nonsense
    # Offset radii as a column (or1) and a row (or2) so that the expressions
    # below broadcast to N x N.
    or1 = radii.reshape((N, 1)) - offset
    or2 = radii.reshape((1, N)) - offset
    sr2 = scales.reshape((1, N)) * or2

    L = np.maximum(or1, abs(r - sr2))
    U = r + sr2
    I = step(r + sr2 - or1) * 0.5 * (1 / L - 1 / U + 0.25 * (r - sr2**2 / r) *
                                     (1 / (U**2) - 1 /
                                      (L**2)) + 0.5 * np.log(L / U) / r)

    # Zero the diagonal (self terms) before summing each row.
    I -= np.diag(np.diag(I))
    I = np.sum(I, axis=1)

    # okay, next compute born radii
    offset_radius = radii - offset
    psi = I * offset_radius
    psi_coefficient = 0.8
    psi2_coefficient = 0
    psi3_coefficient = 2.909125

    psi_term = (psi_coefficient * psi) + (psi2_coefficient *
                                          psi**2) + (psi3_coefficient * psi**3)

    B = 1 / (1 / offset_radius - np.tanh(psi_term) / radii)

    # finally, compute the three energy terms
    E = 0.0

    # single particle
    E += np.sum(surface_tension * (radii + 0.14)**2 * (radii / B)**6)
    # NOTE(review): the `np.sum(` call below is cut off after its numerator.
    E += np.sum(-0.5 * screening *
                (1 / solute_dielectric - 1 / solvent_dielectric) * charges**2 /

    # particle pair
    f = np.sqrt(r**2 + np.outer(B, B) * np.exp(-r**2 / (4 * np.outer(B, B))))
    charge_products = np.outer(charges, charges)

    # NOTE(review): the `np.triu(` / `np.sum(` call below is cut off before
    # its closing arguments.
    E += np.sum(
        np.triu(-screening * (1 / solute_dielectric - 1 / solvent_dielectric) *
                charge_products / f,

    return E
Exemplo n.º 26
Arquivo: ops.py Projeto: dfm/celeriac
def _matmul_impl(state: Carry, data: MatmulData) -> Tuple[Carry, Array]:
    """Advance the matmul recursion by one step.

    ``state`` carries ``(F, V, Y)`` from the previous step and ``data``
    provides ``(V, P, Y)`` for the current one. The new ``F`` is
    ``_pdot(P, F_prev + outer(V_prev, Y_prev))``; it is returned both inside
    the new carry and as the step output.
    """
    prev_F, prev_V, prev_Y = state
    next_V, next_P, next_Y = data
    accumulated = prev_F + jnp.outer(prev_V, prev_Y)
    next_F = _pdot(next_P, accumulated)
    carry = (next_F, next_V, next_Y)
    return carry, next_F
Exemplo n.º 27
Arquivo: ops.py Projeto: dfm/celeriac
def _solve_impl(state: Carry, data: Data) -> Tuple[Carry, Array]:
    """Advance the solve recursion by one step.

    ``state`` carries ``(F, W, Z)`` from the previous step and ``data``
    provides ``(U, W, P, Y)`` for the current one. The new ``F`` is
    ``_pdot(P, F_prev + outer(W_prev, Z_prev))`` and the new ``Z`` is
    ``Y - U @ F``; ``Z`` is also the step output.
    """
    prev_F, prev_W, prev_Z = state
    next_U, next_W, next_P, next_Y = data
    accumulated = prev_F + jnp.outer(prev_W, prev_Z)
    next_F = _pdot(next_P, accumulated)
    next_Z = next_Y - next_U @ next_F
    carry = (next_F, next_W, next_Z)
    return carry, next_Z
Exemplo n.º 28
    def __init__(self, n_samp, seed, outdtype, pardtype, holomorphic):
        # Builds a small test fixture: a parameter pytree template, random
        # samples/weights/parameters/vectors/gradients derived from `seed`, a
        # vmapped test function `f`, and real-valued reference quantities.
        # NOTE(review): many statements in this listing are cut off (the
        # `self.target` dict, the `jax.tree_map(` call, the `.astype(` call,
        # both `astype_unsafe(` returns, the `self.S_real = (` expression) and
        # the `else:` branches of the two `if` statements are missing.

        self.dtype = outdtype

        # Template pytree with leaves of different dtypes.
        self.target = {
            "a": jnp.array([[[0j], [0j]]], dtype=jnp.complex128),
            "b": jnp.array(0, dtype=jnp.float64),
            "c": jnp.array(0j, dtype=jnp.complex64),

        if pardtype is None:  # mixed precision as above
            self.target = jax.tree_map(
                lambda x: astype_unsafe(x, pardtype),

        # Five independent keys derived from the seed, one per random draw.
        k = jax.random.PRNGKey(seed)
        k1, k2, k3, k4, k5 = jax.random.split(k, 5)

        self.samples = jax.random.normal(k1, (n_samp, 2))
        self.w = jax.random.normal(k2, (n_samp,), self.dtype).astype(
        )  # TODO remove astype once its fixed in jax
        self.params = tree_random_normal_like(k3, self.target)
        self.v = tree_random_normal_like(k4, self.target)
        self.grad = tree_random_normal_like(k5, self.target)

        # Two definitions of `f` follow; the second one uses `.conjugate()`
        # on parameters where the first does not.
        if holomorphic:

            @partial(jax.vmap, in_axes=(None, 0))
            def f(params, x):
                return astype_unsafe(
                    params["a"][0][0][0] * x[0]
                    + params["b"] * x[1]
                    + params["c"] * (x[0] * x[1])
                    + jnp.sin(x[1] * params["a"][0][1][0])
                    * jnp.cos(x[0] * params["b"] + 1j)
                    * params["c"],


            @partial(jax.vmap, in_axes=(None, 0))
            def f(params, x):
                return astype_unsafe(
                    params["a"][0][0][0].conjugate() * x[0]
                    + params["b"] * x[1]
                    + params["c"] * (x[0] * x[1])
                    + jnp.sin(x[1] * params["a"][0][1][0])
                    * jnp.cos(x[0] * params["b"].conjugate() + 1j)
                    * params["c"].conjugate(),

        self.f = f

        # Flattened real-valued views and reference quantities derived from
        # them: per-sample gradients, their mean, the centered gradients, and
        # a matrix S built from the centered gradients.
        self.params_real_flat = tree_toreal_flat(self.params)
        self.grad_real_flat = tree_toreal_flat(self.grad)
        self.v_real_flat = tree_toreal_flat(self.v)
        self.ok_real = self.grads_real(self.params_real_flat, self.samples)
        self.okmean_real = self.ok_real.mean(axis=0)
        self.dok_real = self.ok_real - self.okmean_real
        self.S_real = (
            self.dok_real.conjugate().transpose() @ self.dok_real / n_samp
        # Rescale S by the outer product of the square roots of its diagonal.
        self.scale = jnp.sqrt(self.S_real.diagonal())
        self.S_real_scaled = self.S_real / (jnp.outer(self.scale, self.scale))
Exemplo n.º 29
def normal_expected_stats(normal):
    """Return ``(outer(mu, mu) + sigma, mu)`` for the given parameters.

    ``normal`` is unpacked into ``normal_nat_to_std``, which supplies
    ``mu`` and ``sigma``.
    """
    mean, cov = normal_nat_to_std(*normal)
    second_moment = np.outer(mean, mean) + cov
    return second_moment, mean
Exemplo n.º 30
def niw_std_to_nat(mu, kappa, psi, nu):
    """Map the parameters ``(mu, kappa, psi, nu)`` to a 4-tuple of natural parameters.

    Returns, in order: ``kappa * outer(mu, mu) + psi``, ``kappa * mu``,
    ``kappa`` and ``nu + psi.shape[0] + 2``.
    """
    return (
        kappa * np.outer(mu, mu) + psi,
        kappa * mu,
        kappa,
        nu + psi.shape[0] + 2,
    )