Example #1
 def phif(s):
     # phi(s) = 1 / ||p(s)|| - 1 / trust_radius; abarsq, lambar and
     # trust_radius are free variables captured from the enclosing
     # (trust-region) solver.
     pmagsq = np.sum(abarsq / np.square(lambar + s), axis=-1)
     pmag = np.sqrt(pmagsq)
     phipartial = np.reciprocal(pmag)
     # Guard the singular case where -s coincides with an entry of lambar.
     singular = np.any(np.equal(-s, lambar), axis=-1)
     phipartial = np.where(singular, 0., phipartial)
     phi = phipartial - np.reciprocal(trust_radius)
     return phi
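
phif above has the shape of the secular equation used in trust-region subproblem solvers: phi crosses zero at the shift s that puts the step exactly on the trust-region boundary. A self-contained sketch of root-finding on such a phi; the values of abarsq (squared gradient components in the Hessian eigenbasis), lambar (eigenvalues) and trust_radius below are made-up assumptions:

import numpy as np
from scipy.optimize import brentq

abarsq = np.array([1.0, 4.0])   # assumed squared eigenbasis components
lambar = np.array([0.5, 2.0])   # assumed eigenvalues
trust_radius = 1.0

def phi(s):
    pmag = np.sqrt(np.sum(abarsq / np.square(lambar + s)))
    return np.reciprocal(pmag) - np.reciprocal(trust_radius)

# phi is increasing on (-min(lambar), inf), so a bracketing solver applies.
s_star = brentq(phi, 1e-6, 1e6)
print(s_star, phi(s_star))  # phi(s_star) ~ 0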
Example #2
def minimum_volume_enclosing_ellipsoid(points,
                                       tol,
                                       init_u=None,
                                       return_u=False):
    """
    Performs the algorithm of
    MINIMUM VOLUME ENCLOSING ELLIPSOIDS
    NIMA MOSHTAGH
    psuedo-code here:
    https://stackoverflow.com/questions/1768197/bounding-ellipse

    Args:
        points: [N, D]
    """
    N, D = points.shape
    Q = jnp.concatenate([points, jnp.ones([N, 1])], axis=1)  # N,D+1

    def body(state):
        (count, err, u) = state
        V = Q.T @ jnp.diag(u) @ Q  # D+1, D+1
        # g[i] = Q[i,:] @ V^{-1} @ Q[i,:], i.e. the diagonal of
        # Q @ jnp.linalg.solve(V, Q.T), computed row-wise with vmap.
        g = vmap(lambda q: q @ jnp.linalg.solve(V, q))(Q)
        j = jnp.argmax(g)
        g_max = g[j]

        step_size = (g_max - D - 1) / ((D + 1) * (g_max - 1))
        # Increase the weight of the worst-fitting point j, shrink the rest:
        # equivalent to new_u = u + step_size * where(arange(N) == j, 1 - u, -u).
        new_u = jnp.where(
            jnp.arange(N) == j, u + step_size * (1. - u), u * (1. - step_size))
        new_err = jnp.linalg.norm(u - new_u)
        return (count + 1, new_err, new_u)

    if init_u is None:
        init_u = jnp.ones(N) / N
    (count, err,
     u) = while_loop(lambda state: state[1] > tol * jnp.linalg.norm(init_u),
                     body, (0, jnp.inf, init_u))
    U = jnp.diag(u)
    PU = points.T @ u  # [D]
    A = jnp.reciprocal(D) * jnp.linalg.pinv(points.T @ U @ points -
                                            PU[:, None] @ PU[None, :])
    c = PU
    W, Q, Vh = jnp.linalg.svd(A)
    radii = jnp.reciprocal(jnp.sqrt(Q))
    rotation = Vh.conj().T
    if return_u:
        return c, radii, rotation, u
    return c, radii, rotation
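
A brief usage sketch (assuming jax.numpy as jnp, jax.random, and the vmap/while_loop helpers used above are in scope): every input point should end up with normalized radius at most about 1 in the ellipsoid frame.

import jax.numpy as jnp
from jax import random

points = random.uniform(random.PRNGKey(42), (50, 2))
center, radii, rotation = minimum_volume_enclosing_ellipsoid(points, tol=1e-3)
# Normalized radius of each point in the ellipsoid frame.
r = jnp.linalg.norm((rotation.T @ (points.T - center[:, None])) / radii[:, None], axis=0)
print(r.max())  # ~ 1 or less, up to tol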
Example #3
    def init_fn(z, rng, step_size=1.0, inverse_mass_matrix=None, mass_matrix_size=None):
        """
        :param z: Initial position of the integrator.
        :param jax.random.PRNGKey rng: Random key to be used as the source of randomness.
        :param float step_size: Initial step size.
        :param inverse_mass_matrix: Inverse of the initial mass matrix. If ``None``,
            the inverse mass matrix will be an identity matrix whose size is
            determined by the argument ``mass_matrix_size``.
        :param int mass_matrix_size: Size of the mass matrix.
        :return: initial state of the adapt scheme.
        """
        rng, rng_ss = random.split(rng)
        if inverse_mass_matrix is None:
            assert mass_matrix_size is not None
            if dense_mass:
                inverse_mass_matrix = np.identity(mass_matrix_size)
            else:
                inverse_mass_matrix = np.ones(mass_matrix_size)
            mass_matrix_sqrt = inverse_mass_matrix
        else:
            if dense_mass:
                mass_matrix_sqrt = cholesky_inverse(inverse_mass_matrix)
            else:
                mass_matrix_sqrt = np.sqrt(np.reciprocal(inverse_mass_matrix))

        if adapt_step_size:
            step_size = find_reasonable_step_size(inverse_mass_matrix, z, rng_ss, step_size)
        ss_state = ss_init(np.log(10 * step_size))

        mm_state = mm_init(inverse_mass_matrix.shape[-1])

        window_idx = 0
        return AdaptState(step_size, inverse_mass_matrix, mass_matrix_sqrt,
                          ss_state, mm_state, window_idx, rng)
Example #4
    def quantized_softmax(a):
        # We compute softmax as exp(x-max(x))/sum_i(exp(x_i-max(x))), quantizing
        # intermediate values. Note this differs from the log-domain
        # implementation of softmax used above.
        quant_hparams = softmax_hparams.quant_hparams
        fp_quant_config = QuantOps.FloatQuant(is_scaled=False,
                                              fp_spec=quant_hparams.prec)
        quant_ops = QuantOps.create_symmetric_fp(fp_quant=fp_quant_config,
                                                 bounds=None)

        a = quant_ops.to_quantized(a, dtype=dtype)
        # Note that the max of a quantized vector is necessarily also quantized to
        # the same precision since the max of a vector must be an existing element
        # of the vector, so we don't need to explicitly insert a quantization
        # operator to the output of the max reduction.
        a_max = jnp.max(a, axis=norm_dims, keepdims=True)
        a_minus_max = quant_ops.to_quantized(a - a_max, dtype=dtype)
        a_exp = quant_ops.to_quantized(jnp.exp(a_minus_max), dtype=dtype)

        sum_exp_quantized_reduction = quantization.quantized_sum(
            a_exp,
            axis=norm_dims,
            keepdims=True,
            prec=quant_hparams.reduction_prec)
        sum_exp = quant_ops.to_quantized(sum_exp_quantized_reduction,
                                         dtype=dtype)

        inv_sum_exp = quant_ops.to_quantized(jnp.reciprocal(sum_exp),
                                             dtype=dtype)
        a_softmax = quant_ops.to_quantized(a_exp * inv_sum_exp, dtype=dtype)

        return a_softmax.astype(dtype)
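
For comparison, the unquantized computation that quantized_softmax mirrors, written in plain jax.numpy with no QuantOps dependency (norm_dims stands for the reduction axes, as above):

import jax.numpy as jnp

def reference_softmax(a, norm_dims=(-1,)):
    a_max = jnp.max(a, axis=norm_dims, keepdims=True)
    a_exp = jnp.exp(a - a_max)
    return a_exp * jnp.reciprocal(jnp.sum(a_exp, axis=norm_dims, keepdims=True))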
Example #5
def test_welford_covariance(jitted, diagonal, regularize):
    with optional(jitted,
                  disable_jit()), optional(jitted,
                                           control_flow_prims_disabled()):
        np.random.seed(0)
        loc = np.random.randn(3)
        a = np.random.randn(3, 3)
        target_cov = np.matmul(a, a.T)
        x = np.random.multivariate_normal(loc, target_cov, size=(2000, ))
        x = device_put(x)

        @jit
        def get_cov(x):
            wc_init, wc_update, wc_final = welford_covariance(
                diagonal=diagonal)
            wc_state = wc_init(3)
            wc_state = fori_loop(0, 2000, lambda i, val: wc_update(x[i], val),
                                 wc_state)
            cov, cov_inv_sqrt = wc_final(wc_state, regularize=regularize)
            return cov, cov_inv_sqrt

        cov, cov_inv_sqrt = get_cov(x)

        if diagonal:
            diag_cov = jnp.diagonal(target_cov)
            assert_allclose(cov, diag_cov, rtol=0.06)
            assert_allclose(cov_inv_sqrt,
                            jnp.sqrt(jnp.reciprocal(diag_cov)),
                            rtol=0.06)
        else:
            assert_allclose(cov, target_cov, rtol=0.06)
            assert_allclose(cov_inv_sqrt,
                            jnp.linalg.cholesky(jnp.linalg.inv(cov)),
                            rtol=0.06)
Example #6
 def final_fn(state, regularize=False):
     """
     :param state: Current state of the scheme.
     :param bool regularize: Whether to adjust diagonal for numerical stability.
     :return: a triple of estimated covariance, the square root of precision, and
         the inverse of that square root.
     """
     mean, m2, n = state
     # XXX it is not necessary to check for the case n=1
     cov = m2 / (n - 1)
     if regularize:
         # Regularization from Stan
         scaled_cov = (n / (n + 5)) * cov
         shrinkage = 1e-3 * (5 / (n + 5))
         if diagonal:
             cov = scaled_cov + shrinkage
         else:
             cov = scaled_cov + shrinkage * jnp.identity(mean.shape[0])
     if jnp.ndim(cov) == 2:
         # copy the implementation of distributions.util.cholesky_of_inverse here
         tril_inv = jnp.swapaxes(
             jnp.linalg.cholesky(cov[..., ::-1, ::-1])[..., ::-1, ::-1], -2,
             -1)
         identity = jnp.identity(cov.shape[-1])
         cov_inv_sqrt = solve_triangular(tril_inv, identity, lower=True)
     else:
         tril_inv = jnp.sqrt(cov)
         cov_inv_sqrt = jnp.reciprocal(tril_inv)
     return cov, cov_inv_sqrt, tril_inv
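
A standalone numeric check of the reverse-Cholesky trick copied above: the returned cov_inv_sqrt is a lower-triangular L with L @ L.T equal to inv(cov).

import jax.numpy as jnp
from jax.scipy.linalg import solve_triangular

a = jnp.array([[2.0, 0.3], [0.3, 1.0]])
cov = a @ a.T
tril_inv = jnp.swapaxes(
    jnp.linalg.cholesky(cov[..., ::-1, ::-1])[..., ::-1, ::-1], -2, -1)
cov_inv_sqrt = solve_triangular(tril_inv, jnp.identity(2), lower=True)
print(jnp.allclose(cov_inv_sqrt @ cov_inv_sqrt.T, jnp.linalg.inv(cov), atol=1e-5))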
Example #7
def pow_right(y, z, ildj_):
    # x ** y = z
    # x = f^-1(z) = z ** (1 / y)
    # grad(f^-1)(z) = 1 / y * z ** (1 / y - 1)
    # log(grad(f^-1)(z)) = (1 / y - 1)log(z) - log(y)
    y_inv = np.reciprocal(y)
    return lax.pow(z, y_inv), ildj_ + (y_inv - 1.) * np.log(z) - np.log(y)
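
A numeric sanity check (a sketch; np is jax.numpy as in the snippet): the ildj increment should equal log|d/dz z**(1/y)| computed by autodiff.

import jax
import jax.numpy as np
from jax import lax

y, z = 3.0, 8.0
x, ildj = pow_right(y, z, 0.0)
grad_inv = jax.grad(lambda t: lax.pow(t, np.reciprocal(y)))(z)
print(np.allclose(x, 2.0), np.allclose(ildj, np.log(np.abs(grad_inv))))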
Example #8
def _sample_momentum(unpack_fn, inverse_mass_matrix, rng):
    if inverse_mass_matrix.ndim == 1:
        r = dist.norm(0., np.sqrt(
            np.reciprocal(inverse_mass_matrix))).rvs(random_state=rng)
        return unpack_fn(r)
    elif inverse_mass_matrix.ndim == 2:
        raise NotImplementedError
Example #9
def restricted_mp2(geom, basis_name, xyz_path, nuclear_charges, charge, deriv_order=0):
    nelectrons = int(jnp.sum(nuclear_charges)) - charge
    ndocc = nelectrons // 2
    E_scf, C, eps, G = restricted_hartree_fock(geom, basis_name, xyz_path, nuclear_charges, charge, deriv_order=deriv_order, return_aux_data=True)

    nvirt = G.shape[0] - ndocc
    nbf = G.shape[0]

    G = partial_tei_transformation(G, C[:,:ndocc],C[:,ndocc:],C[:,:ndocc],C[:,ndocc:])

    # Create a tensor of shape (occ, vir, occ, vir) with all possible
    # orbital-energy denominators. The partial TEI transformation above is
    # very efficient; building this denominator tensor is the expensive part.
    eps_occ, eps_vir = eps[:ndocc], eps[ndocc:]
    e_denom = jnp.reciprocal(eps_occ.reshape(-1, 1, 1, 1) - eps_vir.reshape(-1, 1, 1) + eps_occ.reshape(-1, 1) - eps_vir)

    # Tensor contraction algorithm (commented out; memory-hungry):
    #mp2_correlation = jnp.einsum('iajb,iajb,iajb->', G, G, e_denom) +\
    #                  jnp.einsum('iajb,iajb,iajb->', G - jnp.transpose(G, (0,3,2,1)), G, e_denom)
    #mp2_total_energy = mp2_correlation + E_scf
    #return E_scf + mp2_correlation

    # Loop algo (lower memory, but tei transform is the memory bottleneck)
    # Create all combinations of four loop variables to make XLA compilation easier
    indices = cartesian_product(jnp.arange(ndocc),jnp.arange(ndocc),jnp.arange(nvirt),jnp.arange(nvirt))
    with loops.Scope() as s:
      s.mp2_correlation = 0.
      for idx in s.range(indices.shape[0]):
        i,j,a,b = indices[idx]
        s.mp2_correlation += G[i, a, j, b] * (2 * G[i, a, j, b] - G[i, b, j, a]) * e_denom[i,a,j,b]
      return E_scf + s.mp2_correlation
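
For reference, the loop above accumulates the standard closed-shell MP2 pair-energy sum; with enough memory the same value is a single elementwise contraction over the transformed integrals, matching the commented-out einsum variant (a sketch that would replace the loop inside this function):

mp2_correlation = jnp.sum(
    G * (2. * G - jnp.transpose(G, (0, 3, 2, 1))) * e_denom)
# ... and then: return E_scf + mp2_correlation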
Example #10
def _sample_momentum(unpack_fn, inverse_mass_matrix, rng):
    if inverse_mass_matrix.ndim == 1:
        r = dist.Normal(0., np.sqrt(
            np.reciprocal(inverse_mass_matrix))).sample(rng)
        return unpack_fn(r)
    elif inverse_mass_matrix.ndim == 2:
        raise NotImplementedError
Example #11
    def body(state):
        p_k = -(state.H_k @ state.g_k)
        line_search_results = line_search(value_and_grad, state.x_k, p_k, old_fval=state.f_k, gfk=state.g_k,
                                          maxiter=ls_maxiter)
        state = state._replace(nfev=state.nfev + line_search_results.nfev,
                               ngev=state.ngev + line_search_results.ngev,
                               failed=line_search_results.failed,
                               ls_status=line_search_results.status)
        s_k = line_search_results.a_k * p_k
        x_kp1 = state.x_k + s_k
        f_kp1 = line_search_results.f_k
        g_kp1 = line_search_results.g_k
        # print(g_kp1)
        y_k = g_kp1 - state.g_k
        rho_k = jnp.reciprocal(y_k @ s_k)

        sy_k = s_k[:, None] * y_k[None, :]
        w = jnp.eye(d) - rho_k * sy_k
        H_kp1 = jnp.where(jnp.isfinite(rho_k),
                          jnp.linalg.multi_dot([w, state.H_k, w.T]) + rho_k * s_k[:, None] * s_k[None, :], state.H_k)

        converged = jnp.linalg.norm(g_kp1, ord=norm) < gtol

        state = state._replace(converged=converged,
                               k=state.k + 1,
                               x_k=x_kp1,
                               f_k=f_kp1,
                               g_k=g_kp1,
                               H_k=H_kp1
                               )

        return state
Example #12
        def quantized_layernorm(x):
            prec = hparams.quant_hparams.prec
            fp_quant = QuantOps.FloatQuant(is_scaled=False, fp_spec=prec)
            quant_ops = QuantOps.create_symmetric_fp(fp_quant=fp_quant,
                                                     bounds=None)

            def to_quantized(x):
                return quant_ops.to_quantized(x, dtype=dtype)

            # If epsilon is too small to represent in the quantized format, we set it
            # to the minimal representative non-zero value to avoid the possibility of
            # dividing by zero.
            fp_bounds = quantization.fp_cast.get_bounds(
                prec.exp_min, prec.exp_max, prec.sig_bits)
            epsilon = max(self.epsilon, fp_bounds.flush_to_zero_bound)
            quantized_epsilon = to_quantized(jnp.array(epsilon, dtype=dtype))

            # If the reciprocal of the quantized number of features is too small to
            # represent in the quantized format, we set it to the minimal
            # representative nonzero value so that the mean and variance are not
            # trivially 0.
            num_features_quantized = to_quantized(
                jnp.array(num_features, dtype=dtype))
            num_features_recip_quantized = to_quantized(
                jnp.reciprocal(num_features_quantized))
            num_features_recip_quantized = jax.lax.cond(
                jax.lax.eq(num_features_recip_quantized,
                           0.0), lambda _: quantized_epsilon,
                lambda _: num_features_recip_quantized, None)

            x_quantized = to_quantized(x)
            x_sum_quantized_reduction = quantization.quantized_sum(
                x_quantized,
                axis=-1,
                keepdims=True,
                prec=hparams.quant_hparams.reduction_prec)
            x_sum = to_quantized(x_sum_quantized_reduction)
            mean = to_quantized(x_sum * num_features_recip_quantized)
            x_minus_mean = to_quantized(x - mean)
            x_sq = to_quantized(lax.square(x_minus_mean))
            x_sq_sum_quantized_reduction = quantization.quantized_sum(
                x_sq,
                axis=-1,
                keepdims=True,
                prec=hparams.quant_hparams.reduction_prec)
            x_sq_sum = to_quantized(x_sq_sum_quantized_reduction)
            var = to_quantized(x_sq_sum * num_features_recip_quantized)
            # Prevent division by zero.
            var_plus_epsilon = to_quantized(var + quantized_epsilon)
            mul = to_quantized(lax.rsqrt(var_plus_epsilon))
            if self.use_scale:
                quantized_scale_param = to_quantized(scale_param)
                mul = to_quantized(mul * quantized_scale_param)
            y = to_quantized(x_minus_mean * mul)
            if self.use_bias:
                quantized_bias_param = to_quantized(bias_param)
                y = to_quantized(y + quantized_bias_param)
            return y.astype(self.dtype)
Example #13
    def body_fun(state: LBFGSResults):
        # find search direction
        p_k = _two_loop_recursion(state)

        # line search
        ls_results = line_search(
            f=fun,
            xk=state.x_k,
            pk=p_k,
            old_fval=state.f_k,
            gfk=state.g_k,
            maxiter=maxls,
        )

        # evaluate at next iterate
        s_k = ls_results.a_k * p_k
        x_kp1 = state.x_k + s_k
        f_kp1 = ls_results.f_k
        g_kp1 = ls_results.g_k
        y_k = g_kp1 - state.g_k
        rho_k_inv = jnp.real(_dot(y_k, s_k))
        rho_k = jnp.reciprocal(rho_k_inv)
        gamma = rho_k_inv / jnp.real(_dot(jnp.conj(y_k), y_k))

        # replacements for next iteration
        status = 0
        status = jnp.where(state.f_k - f_kp1 < ftol, 4, status)
        status = jnp.where(state.ngev >= maxgrad, 3, status)  # type: ignore
        status = jnp.where(state.nfev >= maxfun, 2, status)  # type: ignore
        status = jnp.where(state.k >= maxiter, 1, status)  # type: ignore
        status = jnp.where(ls_results.failed, 5, status)

        converged = jnp.linalg.norm(g_kp1, ord=norm) < gtol

        # TODO(jakevdp): use a fixed-point procedure rather than type-casting?
        state = state._replace(
            converged=converged,
            failed=(status > 0) & (~converged),
            k=state.k + 1,
            nfev=state.nfev + ls_results.nfev,
            ngev=state.ngev + ls_results.ngev,
            x_k=x_kp1.astype(state.x_k.dtype),
            f_k=f_kp1.astype(state.f_k.dtype),
            g_k=g_kp1.astype(state.g_k.dtype),
            s_history=_update_history_vectors(history=state.s_history,
                                              new=s_k),
            y_history=_update_history_vectors(history=state.y_history,
                                              new=y_k),
            rho_history=_update_history_scalars(history=state.rho_history,
                                                new=rho_k),
            gamma=gamma,
            status=jnp.where(converged, 0, status),
            ls_status=ls_results.status,
        )

        return state
Example #14
def gaussian_euclidean_metric(
    inverse_mass_matrix: np.DeviceArray,
) -> Tuple[Callable, Callable]:
    """Emulate dynamics on an Euclidean Manifold [1]_ for vanilla Hamiltonian
    Monte Carlo with a standard gaussian as the conditional probability density
    :math:`\\pi(momentum|position)`.

    References
    ----------
    .. [1]: Betancourt, Michael. "A general metric for Riemannian manifold
            Hamiltonian Monte Carlo." International Conference on Geometric Science of
            Information. Springer, Berlin, Heidelberg, 2013.
    """

    ndim = np.ndim(inverse_mass_matrix)
    shape = np.shape(inverse_mass_matrix)[:1]

    if ndim == 1:  # diagonal mass matrix

        mass_matrix_sqrt = np.sqrt(np.reciprocal(inverse_mass_matrix))

        @jax.jit
        def momentum_generator(rng_key: jax.random.PRNGKey) -> np.DeviceArray:
            std = jax.random.normal(rng_key, shape)
            p = np.multiply(std, mass_matrix_sqrt)
            return p

        @jax.jit
        def kinetic_energy(momentum: np.DeviceArray) -> float:
            velocity = np.multiply(inverse_mass_matrix, momentum)
            return 0.5 * np.dot(velocity, momentum)

        return momentum_generator, kinetic_energy

    elif ndim == 2:

        mass_matrix_sqrt = cholesky_triangular(inverse_mass_matrix)

        @jax.jit
        def momentum_generator(rng_key: jax.random.PRNGKey) -> np.DeviceArray:
            std = jax.random.normal(rng_key, shape)
            p = np.dot(std, mass_matrix_sqrt)
            return p

        @jax.jit
        def kinetic_energy(momentum: np.DeviceArray) -> float:
            velocity = np.matmul(inverse_mass_matrix, momentum)
            return 0.5 * np.dot(velocity, momentum)

        return momentum_generator, kinetic_energy

    else:
        raise ValueError(
            "The mass matrix has the wrong number of dimensions: "
            + "expected 1 or 2, got {}.".format(np.ndim(inverse_mass_matrix))
        )
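
A usage sketch for the diagonal branch (assuming the jax/jax.numpy aliases used above): with a unit diagonal inverse mass matrix, the kinetic energy reduces to 0.5 * ||p||^2.

import jax
import jax.numpy as np

momentum_generator, kinetic_energy = gaussian_euclidean_metric(np.ones(3))
p = momentum_generator(jax.random.PRNGKey(0))
print(kinetic_energy(p), 0.5 * np.dot(p, p))  # should agree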
Example #15
def integer_pow_inverse(z, *, y):
    """Inverse for `integer_pow_p` primitive."""
    if y == 0:
        raise ValueError('Cannot invert raising a value to the 0-th power.')
    elif y == 1:
        return z
    elif y == -1:
        return np.reciprocal(z)
    elif y == 2:
        return np.sqrt(z)
    return lax.pow(z, 1. / y)
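
Quick sanity checks (assuming np is jax.numpy and lax is jax.lax, as in the snippet); each branch returns an x with x**y == z:

print(integer_pow_inverse(8.0, y=1))    # 8.0
print(integer_pow_inverse(9.0, y=2))    # 3.0
print(integer_pow_inverse(0.25, y=-1))  # 4.0
print(integer_pow_inverse(8.0, y=3))    # ~2.0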
Example #16
 def precision_matrix(self):
     # We use "Woodbury matrix identity" to take advantage of low rank form::
     #     inv(W @ W.T + D) = inv(D) - inv(D) @ W @ inv(C) @ W.T @ inv(D)
     # where :math:`C` is the capacitance matrix.
     Wt_Dinv = (np.swapaxes(self.cov_factor, -1, -2)
                / np.expand_dims(self.cov_diag, axis=-2))
     # Solve L_C @ A = Wt_Dinv with the Cholesky factor of C, so that
     # A.T @ A = inv(D) @ W @ inv(C) @ W.T @ inv(D).
     A = solve_triangular(self._capacitance_tril, Wt_Dinv, lower=True)
     # TODO: find a better solution to create a diagonal matrix
     inverse_cov_diag = np.reciprocal(self.cov_diag)
     diag_embed = inverse_cov_diag[..., np.newaxis] * np.identity(self.loc.shape[-1])
     return diag_embed - np.matmul(np.swapaxes(A, -1, -2), A)
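
A standalone numeric check of the Woodbury identity quoted in the comment (made-up shapes; W plays the role of cov_factor and D the diagonal cov_diag):

import jax.numpy as jnp

W = jnp.array([[1.0, 0.0], [0.5, 1.0], [0.0, 2.0]])  # [D=3, rank=2]
D = jnp.array([1.0, 2.0, 3.0])
cov = W @ W.T + jnp.diag(D)
C = jnp.eye(2) + (W.T / D) @ W                        # capacitance matrix
woodbury = jnp.diag(1.0 / D) - (W.T / D).T @ jnp.linalg.inv(C) @ (W.T / D)
print(jnp.allclose(woodbury, jnp.linalg.inv(cov), atol=1e-5))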
Example #17
def scaleSqFromModelParsSingleMu(A, e, M, W, etas, binCenters1, good_idx):

    coeffe1 = binCenters1[..., 1]
    coeffM1 = binCenters1[..., 0]

    #term1 = A[good_idx[0]]-e[good_idx[0]]*coeffe1+M[good_idx[0]]*coeffM1 + W[good_idx[0]]*np.abs(coeffM1)
    term1 = (A[good_idx[0]] - 1.) + np.reciprocal(
        1 + e[good_idx[0]] *
        coeffe1) + M[good_idx[0]] * coeffM1 + W[good_idx[0]] * np.abs(coeffM1)

    scaleSq = np.square(1. - term1)

    return scaleSq
Example #18
def kl_divergence(p, q):
    # From https://arxiv.org/abs/1605.06197 Formula (12)
    a, b = p.concentration1, p.concentration0
    alpha, beta = q.concentration1, q.concentration0
    b_reciprocal = jnp.reciprocal(b)
    a_b = a * b
    t1 = (alpha / a - 1) * (jnp.euler_gamma + digamma(b) + b_reciprocal)
    t2 = jnp.log(a_b) + betaln(alpha, beta) + (b_reciprocal - 1)
    a_ = jnp.expand_dims(a, -1)
    b_ = jnp.expand_dims(b, -1)
    a_b_ = jnp.expand_dims(a_b, -1)
    m = jnp.arange(1, p.KL_KUMARASWAMY_BETA_TAYLOR_ORDER + 1)
    t3 = (beta - 1) * b * (jnp.exp(betaln(m / a_, b_)) / (m + a_b_)).sum(-1)
    return t1 + t2 + t3
Example #19
def MSAWeight_PB(msa):
    gap_idx = msa.abc.charmap['-']
    q = msa.abc.q
    ax = msa.ax
    (N, L) = ax.shape

    ## step 1: get counts:

    c = np.sum(msa.ax_1hot, axis=0)

    # set gap counts to 0
    c = index_update(c, index[:, gap_idx], 0)

    # get N x L array with count value for corresponding residue in alignment
    # first, get  N x L "column id" array (convenient for vmap)
    # col_id[n,i] = i
    col_id = np.int16(np.tensordot(np.ones(N), np.arange(L), axes=0))
    # ax_c[n, i] = c[i, ax[n,i]]
    ax_c = Get_Henikoff_Counts_Residue(col_id, ax, c)

    ## step 2: get number of unique characters in each column
    r = np.float32(np.sum(np.array(c > 0), axis=1))

    # transform r from an L-vector to an N x L array, where r2[n,i] = r[i]
    # will allow for easy elementwise operations with ax_c
    r2 = np.tensordot(np.ones(N), r, axes=0)

    ## step 3: get ungapped seq lengths
    nongap = np.array(ax != gap_idx)
    l = np.float32(np.sum(nongap, axis=1))

    ## step 4: calculate unnormalized weights
    # get array of main terms in Henikoff sum
    # wgt_un[n,i] = 1 / (r[i] * c[i, ax[n,i]])
    wgt_un = np.reciprocal(np.multiply(ax_c, r2))

    # set all terms involving  gap to zero
    wgt_un = np.nan_to_num(np.multiply(wgt_un, nongap))

    # sum across all positions to get the preliminary unnormalized weight for each sequence
    wgt_un = np.sum(wgt_un, axis=1)

    # divide by gapless sequence length
    wgt_un = np.divide(wgt_un, l)

    ## step 5: normalize sequence weights
    wgt = (wgt_un * np.float32(N)) / np.sum(wgt_un)
    msa.wgt = wgt

    return
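
A toy check of the Henikoff position-based rule used above, in plain numpy (gap-free alignment, two-letter alphabet; wgt_un[n] = sum_i 1 / (r[i] * c[i, ax[n, i]]) before length normalization):

import numpy as np

ax = np.array([[0, 0],
               [0, 1],
               [1, 1]])  # N=3 sequences, L=2 columns
c = np.stack([np.bincount(ax[:, i], minlength=2) for i in range(2)])  # [L, q] counts
r = (c > 0).sum(axis=1)  # unique residues per column
wgt_un = np.array([sum(1.0 / (r[i] * c[i, ax[n, i]]) for i in range(2))
                   for n in range(3)]) / 2.0  # divide by ungapped length
wgt = wgt_un * 3 / wgt_un.sum()  # normalize to sum to N
print(wgt)  # [1.125 0.75 1.125]: the "most average" sequence gets the lowest weight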
Example #20
def sigmaSqFromModelParsSingleMu(a, b, c, d, etas, binCenters1, good_idx):

    #compute sigma from physics parameters

    pt2 = binCenters1[..., 2]
    L2 = binCenters1[..., 3]
    #corr = binCenters1[...,4]
    invpt2 = binCenters1[..., 4]

    sigmaSq = a[good_idx[0]] * L2 + c[good_idx[0]] * pt2 * np.square(L2) + b[
        good_idx[0]] * L2 * np.reciprocal(1 + d[good_idx[0]] * invpt2 / L2)
    #sigmaSq = a[good_idx[0]]*L2 + c[good_idx[0]]*pt2*np.square(L2) + corr

    return sigmaSq
Example #21
    def var_exp(self, freqs, Y_obs, sigma, amp, mu, gamma):
        """
        Computes variational expectation
        Args:
            freqs: [Nf]
            Y_obs: [Nf]
            sigma: [Nf]
            amp: [Nf]
            mu: [M]
            gamma: [M]

        Returns: scalar

        """
        f = self._phase_basis(self.freqs)  # Nf,M
        Nf = freqs.size
        Sigma_real = jnp.square(sigma[:Nf])
        Sigma_imag = jnp.square(sigma[Nf:])
        Yreal = Y_obs[:Nf]
        Yimag = Y_obs[Nf:]
        a = jnp.reciprocal(Sigma_real)
        b = jnp.reciprocal(Sigma_imag)
        constant = -Nf * jnp.log(2. * jnp.pi)
        logdet = -jnp.sum(jnp.log(sigma))

        phi = jnp.dot(f, mu)
        theta = jnp.dot(jnp.square(f), jnp.square(gamma))

        exp_cos = jnp.exp(-0.5 * theta) * jnp.cos(phi)
        exp_sin = jnp.exp(-0.5 * theta) * jnp.sin(phi)
        exp_cos2 = 0.5 * (jnp.exp(-2. * theta) * jnp.cos(2. * phi) + 1.)

        negtwo_maha = a * (jnp.square(Yreal) - 2. * amp * Yreal * exp_cos) \
                      + b * (jnp.square(Yimag) - 2. * amp * Yimag * exp_sin) \
                      + ((a - b) * exp_cos2 + b) * jnp.square(amp)

        return constant + logdet - 0.5 * jnp.sum(negtwo_maha)
Example #22
    def init_fn(z_info,
                rng_key,
                step_size=1.0,
                inverse_mass_matrix=None,
                mass_matrix_size=None):
        """
        :param IntegratorState z_info: The initial integrator state.
        :param jax.random.PRNGKey rng_key: Random key to be used as the source of randomness.
        :param float step_size: Initial step size.
        :param inverse_mass_matrix: Inverse of the initial mass matrix. If ``None``,
            the inverse mass matrix will be an identity matrix whose size is
            determined by the argument ``mass_matrix_size``.
        :param int mass_matrix_size: Size of the mass matrix.
        :return: initial state of the adapt scheme.
        """
        rng_key, rng_key_ss = random.split(rng_key)
        if inverse_mass_matrix is None:
            assert mass_matrix_size is not None
            if dense_mass:
                inverse_mass_matrix = jnp.identity(mass_matrix_size)
            else:
                inverse_mass_matrix = jnp.ones(mass_matrix_size)
            mass_matrix_sqrt = mass_matrix_sqrt_inv = inverse_mass_matrix
        else:
            if dense_mass:
                mass_matrix_sqrt_inv = jnp.swapaxes(
                    jnp.linalg.cholesky(
                        inverse_mass_matrix[..., ::-1, ::-1])[..., ::-1, ::-1],
                    -2, -1)
                identity = jnp.identity(inverse_mass_matrix.shape[-1])
                mass_matrix_sqrt = solve_triangular(mass_matrix_sqrt_inv,
                                                    identity,
                                                    lower=True)
            else:
                mass_matrix_sqrt_inv = jnp.sqrt(inverse_mass_matrix)
                mass_matrix_sqrt = jnp.reciprocal(mass_matrix_sqrt_inv)

        if adapt_step_size:
            step_size = find_reasonable_step_size(step_size,
                                                  inverse_mass_matrix, z_info,
                                                  rng_key_ss)
        ss_state = ss_init(jnp.log(10 * step_size))

        mm_state = mm_init(inverse_mass_matrix.shape[-1])

        window_idx = 0
        return HMCAdaptState(step_size, inverse_mass_matrix, mass_matrix_sqrt,
                             mass_matrix_sqrt_inv, ss_state, mm_state,
                             window_idx, rng_key)
Example #23
 def unquantized_layernorm(x):
     num_features_recip = jnp.reciprocal(num_features)
     x_sum = jnp.sum(x, axis=-1, keepdims=True)
     mean = x_sum * num_features_recip
     x_minus_mean = x - mean
     x_sq = lax.square(x_minus_mean)
     x_sq_sum = jnp.sum(x_sq, axis=-1, keepdims=True)
     var = x_sq_sum * num_features_recip
     var_plus_epsilon = var + self.epsilon
     mul = lax.rsqrt(var_plus_epsilon)
     if self.use_scale:
         mul = mul * scale_param
     y = x_minus_mean * mul
     if self.use_bias:
         y = y + bias_param
     return y.astype(self.dtype)
Example #24
 def body(state):
     (i, _, key, done, _) = state
     key, accept_key, sample_key, select_key = random.split(key, 4)
     k = random.categorical(select_key, log_p)
     mu_k = mu[k, :]
     radii_k = radii[k, :]
     rotation_k = rotation[k, :, :]
     u_test = sample_ellipsoid(sample_key,
                               mu_k,
                               radii_k,
                               rotation_k,
                               unit_cube_constraint=unit_cube_constraint)
     inside = vmap(lambda mu, radii, rotation: point_in_ellipsoid(
         u_test, mu, radii, rotation))(mu, radii, rotation)
     n_intersect = jnp.sum(inside)
     done = (random.uniform(accept_key) < jnp.reciprocal(n_intersect))
     return (i + 1, k, key, done, u_test)
Example #25
def ellipsoid_params(C):
    """
    If C satisfies the sectional inequality,

    (x - mu)^T C (x - mu) <= 1

    then this returns the radius and rotation matrix of the ellipsoid.

    Args:
        C: [D,D]

    Returns: radii [D], rotation [D,D]

    """
    W, Q, Vh = jnp.linalg.svd(C)
    radii = jnp.reciprocal(jnp.sqrt(Q))
    radii = jnp.where(jnp.isnan(radii), 0., radii)
    rotation = Vh.conj().T
    return radii, rotation
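
Usage sketch (assuming jax.numpy as jnp): for C = diag(4, 1), i.e. the ellipse (x / 0.5)^2 + y^2 <= 1, the radii come out as [0.5, 1.0].

import jax.numpy as jnp

C = jnp.diag(jnp.array([4.0, 1.0]))
radii, rotation = ellipsoid_params(C)
print(radii)  # [0.5, 1.0]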
Example #26
def debug_mvee():
    import pylab as plt

    n = random.normal(random.PRNGKey(0), (10000,2))
    n = n / jnp.linalg.norm(n, axis=1, keepdims=True)
    angle = jnp.arctan2(n[:,1], n[:,0])
    plt.hist(angle, bins=100)
    plt.show()
    N = 120
    D = 2
    points = random.uniform(random.PRNGKey(0), (N, D))

    from jax import disable_jit
    with disable_jit():
        center, radii, rotation = minimum_volume_enclosing_ellipsoid(points, 0.01)

    plt.hist(jnp.linalg.norm((rotation.T @ (points.T - center[:, None])) / radii[:, None], axis=0))
    plt.show()
    print(center, radii, rotation)
    plt.scatter(points[:, 0], points[:, 1])
    theta = jnp.linspace(0., jnp.pi*2, 100)
    ellipsis = center[:, None] + rotation @ jnp.stack([radii[0]*jnp.cos(theta), radii[1]*jnp.sin(theta)], axis=0)
    plt.plot(ellipsis[0,:], ellipsis[1,:])

    for i in range(1000):
        y = sample_ellipsoid(random.PRNGKey(i), center, radii, rotation)
        plt.scatter(y[0], y[1])

    C = jnp.linalg.pinv(jnp.cov(points, rowvar=False, bias=True))
    p = (N - D - 1)/N
    def q(p):
        return p + p**2/(4.*(D-1))
    C = C / q(p)
    c = jnp.mean(points, axis=0)
    W, Q, Vh = jnp.linalg.svd(C)
    radii = jnp.reciprocal(jnp.sqrt(Q))
    rotation = Vh.conj().T
    ellipsis = c[:, None] + rotation @ jnp.stack([radii[0] * jnp.cos(theta), radii[1] * jnp.sin(theta)], axis=0)
    plt.plot(ellipsis[0, :], ellipsis[1, :])

    plt.show()
Example #27
  def body_fun(state):
    p_k = -_dot(state.H_k, state.g_k)
    line_search_results = line_search(
        fun,
        state.x_k,
        p_k,
        old_fval=state.f_k,
        old_old_fval=state.old_old_fval,
        gfk=state.g_k,
        maxiter=line_search_maxiter,
    )
    state = state._replace(
        nfev=state.nfev + line_search_results.nfev,
        ngev=state.ngev + line_search_results.ngev,
        failed=line_search_results.failed,
        line_search_status=line_search_results.status,
    )
    s_k = line_search_results.a_k * p_k
    x_kp1 = state.x_k + s_k
    f_kp1 = line_search_results.f_k
    g_kp1 = line_search_results.g_k
    y_k = g_kp1 - state.g_k
    rho_k = jnp.reciprocal(_dot(y_k, s_k))

    sy_k = s_k[:, jnp.newaxis] * y_k[jnp.newaxis, :]
    w = jnp.eye(d, dtype=rho_k.dtype) - rho_k * sy_k
    H_kp1 = (_einsum('ij,jk,lk', w, state.H_k, w)
             + rho_k * s_k[:, jnp.newaxis] * s_k[jnp.newaxis, :])
    H_kp1 = jnp.where(jnp.isfinite(rho_k), H_kp1, state.H_k)
    converged = jnp.linalg.norm(g_kp1, ord=norm) < gtol

    state = state._replace(
        converged=converged,
        k=state.k + 1,
        x_k=x_kp1,
        f_k=f_kp1,
        g_k=g_kp1,
        H_k=H_kp1,
        old_old_fval=state.f_k,
    )
    return state
Example #28
 def final_fn(state, regularize=False):
     """
     :param state: Current state of the scheme.
     :param bool regularize: Whether to adjust diagonal for numerical stability.
     :return: a pair of estimated covariance and the square root of precision.
     """
     mean, m2, n = state
     # XXX it is not necessary to check for the case n=1
     cov = m2 / (n - 1)
     if regularize:
         # Regularization from Stan
         scaled_cov = (n / (n + 5)) * cov
         shrinkage = 1e-3 * (5 / (n + 5))
         if diagonal:
             cov = scaled_cov + shrinkage
         else:
             cov = scaled_cov + shrinkage * np.identity(mean.shape[0])
     if np.ndim(cov) == 2:
         cov_inv_sqrt = cholesky_inverse(cov)
     else:
         cov_inv_sqrt = np.sqrt(np.reciprocal(cov))
     return cov, cov_inv_sqrt
Example #29
def sample_ellipsoid(key, center, radii, rotation, unit_cube_constraint=False):
    """
    Sample uniformly inside an ellipsoid.
    When unit_cube_constraint=True, the randomly chosen radius is constrained
    so that the sample also stays inside the unit cube [0, 1]^D.

    For a point u(t) = t * (R @ n) + c on a ray from the center c, the
    unit-cube boundary crossings satisfy u(t) == 1 and u(t) == 0, i.e.
    t1 = (1 - c) / (R @ n) and t0 = -c / (R @ n); the smallest positive
    such t is taken (and capped at 1).
    Args:
        key:
        center: [D]
        radii: [D]
        rotation: [D,D]

    Returns: [D]

    """
    direction_key, radii_key = random.split(key, 2)
    direction = random.normal(direction_key, shape=radii.shape)
    if unit_cube_constraint:
        direction = direction / jnp.linalg.norm(direction)
        R = rotation * radii
        D = R @ direction
        t0 = -center / D
        t1 = jnp.reciprocal(D) + t0
        t0 = jnp.where(t0 < 0., jnp.inf, t0)
        t1 = jnp.where(t1 < 0., jnp.inf, t1)
        t = jnp.minimum(jnp.min(t0), jnp.min(t1))
        t = jnp.minimum(t, 1.)
        return jnp.exp(
            jnp.log(random.uniform(radii_key, minval=0., maxval=t)) /
            radii.size) * D + center
    log_norm = jnp.log(jnp.linalg.norm(direction))
    log_radius = jnp.log(random.uniform(radii_key)) / radii.size
    # x = direction * (radius/norm)
    x = direction * jnp.exp(log_radius - log_norm)
    return circle_to_ellipsoid(x, center, radii, rotation)
Example #30
 def body(state):
     (key, i, u_test, x_test, log_L_test) = state
     key, uniform_key, beta_key = random.split(key, 3)
     # [M]
     U_scale = random.uniform(uniform_key,
                              shape=spawn_point_U.shape,
                              minval=t_L,
                              maxval=t_R)
     t_shrink = random.beta(beta_key, live_points_U.shape[0],
                            1)**jnp.reciprocal(spawn_point_U.size)
     u_test_white = U_scale / t_shrink
     # Undo the whitening: y = dx + sum_i p_i * u_i = dx + R @ u,
     # i.e. x_i = x0_i + R_ij u_j.
     if whiten:
         u_test = L @ (spawn_point_U + R @ u_test_white) + u_mean
     else:
         u_test = u_test_white
     u_test = jnp.clip(u_test, 0., 1.)
     x_test = prior_transform(u_test)
     log_L_test = loglikelihood_from_constrained(**x_test)
     return (key, i + 1, u_test, x_test, log_L_test)