Example #1
    def to_minimize(flat_theta):

        theta = reconstruct(flat_theta, summary, jnp.reshape)
        # Map "log_"-prefixed entries back to positive space via exp
        # (same convention as constrain_positive in Example #5).
        theta = apply_transformation(theta, "log_", jnp.exp, "")

        lik = calculate_likelihood(
            theta,
            species_ids,
            fg_covs,
            bg_covs,
            fg_covs_thin,
            bg_covs_thin,
            quad_weights,
            counts,
            n_s,
            n_fg,
        )
        kl = calculate_kl(theta)

        prior = jnp.sum(
            gamma.logpdf(theta["w_prior_var"], 0.5, scale=1.0 / n_c))
        prior = prior + jnp.sum(
            norm.logpdf(theta["w_prior_mean"], 0.0, scale=jnp.sqrt(1.0 / n_c)))

        # Negative ELBO: likelihood minus KL, plus the hyper-priors above.
        return -(lik - kl + prior)
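A sketch of how such a closure is typically driven with SciPy, mirroring the jit/value_and_grad/minimize imports in Example #5. The toy quadratic below is a stand-in assumption for the real negative ELBO:

import numpy as np
import jax.numpy as jnp
from jax import jit, value_and_grad
from scipy.optimize import minimize


def toy_objective(flat_theta):
    # Stand-in for the real to_minimize, which reconstructs the parameter
    # dict and returns -(lik - kl + prior).
    return jnp.sum((flat_theta - 1.0) ** 2)


# One pass gives both the objective value and its gradient.
obj_and_grad = jit(value_and_grad(toy_objective))


def fun(x):
    val, grad = obj_and_grad(jnp.asarray(x))
    return float(val), np.asarray(grad)


result = minimize(fun, x0=np.zeros(4), jac=True, method="L-BFGS-B")
print(result.x)  # approximately all ones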
Example #2
def loss_MAP(mu,
             tau_unc,
             D,
             i0,
             i1,
             mu0,
             beta=1.0,
             gamma_shape=1.0,
             gamma_rate=1.0,
             alpha=1.0):
    mu_i, mu_j = mu[i0], mu[i1]
    # Softplus keeps the precisions positive; EPSILON and SCALE are
    # module-level constants defined elsewhere.
    tau = EPSILON + jax.nn.softplus(SCALE * tau_unc)
    tau_i, tau_j = tau[i0], tau[i1]

    # Precision of the distance between points i and j (harmonic form).
    tau_ij_inv = tau_i * tau_j / (tau_i + tau_j)
    log_tau_ij_inv = jnp.log(tau_i) + jnp.log(tau_j) - jnp.log(tau_i + tau_j)

    d = jnp.linalg.norm(mu_i - mu_j, ord=2, axis=1, keepdims=True)

    # Rician log-likelihood of the observed distances D given the embedded
    # distances d. i0e is the exponentially scaled Bessel I0; the factor it
    # absorbs is what turns -0.5 * tau * (D**2 + d**2) into
    # -0.5 * tau * (D - d)**2.
    log_llh = (jnp.log(D) + log_tau_ij_inv - 0.5 * tau_ij_inv * (D - d)**2 +
               jnp.log(i0e(tau_ij_inv * D * d)))

    # Log-priors over the embedding positions and precisions.
    log_mu = multivariate_normal.logpdf(mu, mean=mu0, cov=beta * jnp.eye(2))
    log_tau = gamma.logpdf(tau, a=gamma_shape, scale=1.0 / gamma_rate)

    # NOTE: despite the name, this returns the (positive) log posterior;
    # negate it when a quantity to minimise is required.
    return jnp.sum(log_llh) + jnp.sum(log_mu) + jnp.sum(log_tau)
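A usage sketch for loss_MAP. EPSILON and SCALE are module-level constants in the original, so the values here are placeholder assumptions, and the index arrays and data are illustrative only; tau_unc is kept as a column vector so every pairwise term stays shape (m, 1):

import jax
import jax.numpy as jnp
from jax.scipy.special import i0e
from jax.scipy.stats import gamma, multivariate_normal

EPSILON, SCALE = 1e-5, 1.0  # placeholders; the real constants live elsewhere

k1, k2 = jax.random.split(jax.random.PRNGKey(0))
mu = jax.random.normal(k1, (10, 2))          # 2-D embedding positions
tau_unc = jnp.zeros((10, 1))                 # unconstrained precisions, column vector
i0_idx = jnp.array([0, 1, 2])                # first point of each pair
i1_idx = jnp.array([3, 4, 5])                # second point of each pair
D = jnp.abs(jax.random.normal(k2, (3, 1)))   # observed pairwise distances

log_post = loss_MAP(mu, tau_unc, D, i0_idx, i1_idx, jnp.zeros(2))
grads = jax.grad(loss_MAP, argnums=(0, 1))(mu, tau_unc, D, i0_idx, i1_idx, jnp.zeros(2))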
Example #3
File: models.py Project: sagar87/jaxvi
    def log_joint(self, theta: jnp.ndarray) -> jnp.ndarray:
        betas = theta[:2]
        sigma = theta[2]

        beta_prior = norm.logpdf(betas, 0, 10).sum()
        sigma_prior = gamma.logpdf(sigma, a=1, scale=2).sum()
        yhat = jnp.inner(self.x, betas)
        likelihood = norm.logpdf(self.y, yhat, sigma).sum()

        return beta_prior + sigma_prior + likelihood
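A self-contained way to exercise this log joint; LinReg below is a minimal stand-in for the jaxvi model class, not its actual definition:

import jax
import jax.numpy as jnp
from jax.scipy.stats import gamma, norm


class LinReg:
    # Minimal stand-in holding the data; log_joint is copied from above.
    def __init__(self, x, y):
        self.x, self.y = x, y

    def log_joint(self, theta):
        betas = theta[:2]
        sigma = theta[2]
        beta_prior = norm.logpdf(betas, 0, 10).sum()
        sigma_prior = gamma.logpdf(sigma, a=1, scale=2).sum()
        yhat = jnp.inner(self.x, betas)
        likelihood = norm.logpdf(self.y, yhat, sigma).sum()
        return beta_prior + sigma_prior + likelihood


model = LinReg(x=jnp.ones((5, 2)), y=jnp.arange(5.0))
theta = jnp.array([0.5, -0.5, 1.0])  # (beta_0, beta_1, sigma); sigma > 0
print(model.log_joint(theta), jax.grad(model.log_joint)(theta))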
Example #4
def log_normal_gamma_prior(mu,
                           tau,
                           mu0=0.0,
                           beta=1.0,
                           gamma_shape=1.0,
                           gamma_rate=1.0):
    # With a scalar mean and covariance, logpdf is evaluated elementwise,
    # so .sum() runs over all points and both embedding dimensions.
    log_mu = multivariate_normal.logpdf(mu, mean=mu0, cov=beta).sum()
    log_tau = gamma.logpdf(tau, a=gamma_shape, scale=1.0 / gamma_rate)
    return log_mu + log_tau
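This and the other examples all pass scale=1.0 / gamma_rate because gamma.logpdf follows SciPy's shape/scale parameterisation, not shape/rate. A quick check of that conversion:

import jax.numpy as jnp
from jax.scipy.special import gammaln
from jax.scipy.stats import gamma

a, rate, tau = 2.0, 3.0, 0.7
lhs = gamma.logpdf(tau, a=a, scale=1.0 / rate)
# Gamma(shape=a, rate) log-density written out directly.
rhs = a * jnp.log(rate) - gammaln(a) + (a - 1) * jnp.log(tau) - rate * tau
print(jnp.allclose(lhs, rhs))  # True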
Example #5
import jax.numpy as jnp
from ml_tools.constrain import apply_transformation
from functools import partial
from ml_tools.jax_kernels import matern_kernel_32, bias_kernel
import svgp.jax.helpers.svgp_spec as sv
from ml_tools.flattening import flatten_and_summarise, reconstruct
from svgp.jax.quadrature import expectation_1d
from jax.scipy.stats import gamma
from ml_tools.jax import convert_decorator
from jax import jit, value_and_grad
from svgp.jax.likelihoods import bernoulli_probit_lik
from scipy.optimize import minimize


gamma_default_lscale_prior_fn = lambda params: jnp.sum(
    gamma.logpdf(params["lengthscales"], 3.0, scale=3.0)
)

constrain_positive = partial(
    apply_transformation, search_key="log_", transformation=jnp.exp, replace_with=""
)


def ard_kernel_currier(params, base_kernel=matern_kernel_32):
    """A kernel getter to be used with get_kernel_fun. Given a parameter_dict
    containing the entries "lengthscales" and "alpha", returns the base_kernel
    ready to evaluate."""

    curried_kernel_fun = lambda x1, x2, diag_only=False: base_kernel(
        x1, x2, params["lengthscales"], params["alpha"], diag_only
    )

    return curried_kernel_fun
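With the missing return restored above, the currier can be exercised like this; toy_rbf_kernel is an illustrative stand-in (defined here so the sketch is self-contained) matching the (x1, x2, lengthscales, alpha, diag_only) signature the currier expects:

import jax.numpy as jnp


def toy_rbf_kernel(x1, x2, lengthscales, alpha, diag_only=False):
    # Squared-exponential stand-in for ml_tools' matern_kernel_32,
    # with per-dimension (ARD) lengthscales.
    z1, z2 = x1 / lengthscales, x2 / lengthscales
    sq_dists = jnp.sum((z1[:, None, :] - z2[None, :, :]) ** 2, axis=-1)
    K = alpha ** 2 * jnp.exp(-0.5 * sq_dists)
    return jnp.diag(K) if diag_only else K


params = {"lengthscales": jnp.array([1.0, 2.0]), "alpha": 1.5}
kernel_fn = ard_kernel_currier(params, base_kernel=toy_rbf_kernel)
K = kernel_fn(jnp.arange(6.0).reshape(3, 2), jnp.arange(6.0).reshape(3, 2))  # (3, 3)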
Example #6
def nLL_sep(x, uncens_obs, uncens_gammas, cens_obs, cens_gammas):
    a, scale = x
    # Uncensored observations: weighted gamma log-densities.
    uncens = jnp.dot(uncens_gammas, gamma.logpdf(uncens_obs, a=a, scale=scale))
    # Censored observations: weighted log survival probabilities, where
    # gammaincc (jax.scipy.special) is the regularised upper incomplete
    # gamma function, i.e. the gamma survival function.
    cens = jnp.dot(cens_gammas, jnp.log(gammaincc(a, cens_obs / scale)))
    return -1 * (uncens + cens)
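A minimal fitting sketch for the objective above, assuming unit weights; the toy observations and starting point are illustrative only:

import numpy as np
import jax.numpy as jnp
from jax import value_and_grad
from jax.scipy.special import gammaincc
from jax.scipy.stats import gamma
from scipy.optimize import minimize

# Illustrative data: a few uncensored and censored durations, unit weights.
uncens_obs = jnp.array([1.2, 0.7, 2.3, 1.8])
cens_obs = jnp.array([3.0, 2.5])
uncens_gammas = jnp.ones_like(uncens_obs)
cens_gammas = jnp.ones_like(cens_obs)


def fun(x):
    val, grad = value_and_grad(nLL_sep)(
        jnp.asarray(x), uncens_obs, uncens_gammas, cens_obs, cens_gammas)
    return float(val), np.asarray(grad)


# Bounds keep the shape and scale strictly positive.
res = minimize(fun, x0=np.array([1.0, 1.0]), jac=True, bounds=[(1e-3, None)] * 2)
print(res.x)  # fitted (shape, scale)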