def to_minimize(flat_theta):
    """Scalar objective for a minimiser: ``-(likelihood - KL + log prior)``.

    ``flat_theta`` is rebuilt into a parameter dict with ``reconstruct`` (using
    ``summary`` and ``jnp.reshape``) and then passed through
    ``apply_transformation(..., "log_", jnp.exp, "")``. Presumably this
    exponentiates the ``log_``-prefixed entries and strips the prefix; confirm
    in ``apply_transformation``.

    Every other input (``summary``, ``species_ids``, the covariate arrays,
    ``quad_weights``, ``counts``, ``n_s``, ``n_fg``, ``n_c``) is taken from
    the enclosing scope.
    """
    constrained = apply_transformation(
        reconstruct(flat_theta, summary, jnp.reshape), "log_", jnp.exp, ""
    )

    data_term = calculate_likelihood(
        constrained,
        species_ids,
        fg_covs,
        bg_covs,
        fg_covs_thin,
        bg_covs_thin,
        quad_weights,
        counts,
        n_s,
        n_fg,
    )
    kl_term = calculate_kl(constrained)

    # Priors on the weight-prior hyperparameters, both scaled by 1 / n_c:
    # Gamma(0.5, scale=1/n_c) on the variance, Normal(0, sd=sqrt(1/n_c)) on the mean.
    inv_n_c = 1.0 / n_c
    var_term = jnp.sum(
        gamma.logpdf(constrained["w_prior_var"], 0.5, scale=inv_n_c)
    )
    mean_term = jnp.sum(
        norm.logpdf(constrained["w_prior_mean"], 0.0, scale=jnp.sqrt(inv_n_c))
    )
    log_prior = var_term + mean_term

    return -(data_term - kl_term + log_prior)
def loss_MAP(mu, tau_unc, D, i0, i1, mu0, beta=1.0, gamma_shape=1.0,
             gamma_rate=1.0, alpha=1.0):
    """Unnormalised log posterior of point positions and precisions.

    The value is a log density to be *maximised*. It is not negated, so negate
    it before handing it to a minimiser.

    Args:
        mu: Point positions, one row per point. Rows are selected by ``i0`` /
            ``i1`` and pairwise distances are taken over axis 1.
        tau_unc: Unconstrained per-point precisions, mapped to positive values
            via ``EPSILON + softplus(SCALE * tau_unc)``.
        D: Observed distances for the pairs ``(i0[k], i1[k])``. Must
            broadcast against the ``(n_pairs, 1)`` model distances ``d``.
            NOTE(review): presumably shaped ``(n_pairs, 1)``. A flat
            ``(n_pairs,)`` array would broadcast ``D - d`` to an
            ``(n_pairs, n_pairs)`` matrix. Confirm against callers.
        i0, i1: Index arrays selecting the two endpoints of each pair.
        mu0: Prior mean of the positions (broadcast against ``mu``).
        beta: Isotropic prior variance of the positions.
        gamma_shape: Shape of the Gamma prior on the precisions.
        gamma_rate: Rate of the Gamma prior on the precisions.
        alpha: Unused; kept for signature compatibility.

    Returns:
        Scalar: summed pair log likelihood + log prior of positions + log
        prior of precisions.
    """
    mu_i, mu_j = mu[i0], mu[i1]

    tau = EPSILON + jax.nn.softplus(SCALE * tau_unc)
    tau_i, tau_j = tau[i0], tau[i1]
    # Combined precision of a pair: 1 / (1/tau_i + 1/tau_j).
    tau_ij_inv = tau_i * tau_j / (tau_i + tau_j)
    log_tau_ij_inv = jnp.log(tau_i) + jnp.log(tau_j) - jnp.log(tau_i + tau_j)

    d = jnp.linalg.norm(mu_i - mu_j, ord=2, axis=1, keepdims=True)
    # Rice log density of D with non-centrality d and variance 1 / tau_ij_inv.
    # i0e(z) = exp(-|z|) * I0(z), so the exp(tau * D * d) factor of I0 cancels
    # with -tau * (D**2 + d**2) / 2, leaving -tau * (D - d)**2 / 2. This also
    # avoids overflow of I0 for large arguments.
    log_llh = (jnp.log(D) + log_tau_ij_inv
               - 0.5 * tau_ij_inv * (D - d) ** 2
               + jnp.log(i0e(tau_ij_inv * D * d)))

    # Isotropic Normal prior on every position. The dimension is taken from
    # mu instead of being hard-coded to 2 (identical result for 2-D mu).
    dim = mu.shape[-1]
    log_mu = multivariate_normal.logpdf(mu, mean=mu0, cov=beta * jnp.eye(dim))
    # Gamma prior on precisions, expressed as scale = 1 / rate.
    log_tau = gamma.logpdf(tau, a=gamma_shape, scale=1.0 / gamma_rate)
    return jnp.sum(log_llh) + jnp.sum(log_mu) + jnp.sum(log_tau)
def log_joint(self, theta: jnp.ndarray) -> jnp.ndarray:
    """Unnormalised log joint density of a Gaussian linear model.

    The annotations use ``jnp.ndarray`` rather than ``jnp.DeviceArray``.
    ``DeviceArray`` has been removed from recent JAX releases, which made
    defining this function raise ``AttributeError``.

    Args:
        theta: Parameter vector. ``theta[:2]`` are the regression
            coefficients and ``theta[2]`` is the observation noise scale.

    Returns:
        Scalar sum of three terms:
        - the Normal(0, 10) log prior on the coefficients;
        - the Gamma(a=1, scale=2) log prior on the noise scale;
        - the Normal log likelihood of ``self.y`` around
          ``inner(self.x, coefficients)``.
    """
    betas = theta[:2]
    sigma = theta[2]
    beta_prior = norm.logpdf(betas, 0, 10).sum()
    sigma_prior = gamma.logpdf(sigma, a=1, scale=2).sum()
    # Contracts the last axis of self.x with the coefficient vector.
    yhat = jnp.inner(self.x, betas)
    likelihood = norm.logpdf(self.y, yhat, sigma).sum()
    return beta_prior + sigma_prior + likelihood
def log_normal_gamma_prior(mu, tau, mu0=0.0, beta=1.0, gamma_shape=1.0,
                           gamma_rate=1.0):
    """Log density of a Normal prior on ``mu`` plus a Gamma prior on ``tau``.

    Args:
        mu: Location(s). The isotropic Normal(mean=mu0, variance=beta) log
            density is summed over all entries of ``mu``.
        tau: Precision(s) scored under a Gamma prior with shape
            ``gamma_shape`` and rate ``gamma_rate``. Not summed.
        mu0: Prior mean for ``mu``. This was previously ignored (the mean was
            hard-coded to 0.0); the default keeps the old behaviour.
        beta: Prior variance of each entry of ``mu``.
        gamma_shape: Shape of the Gamma prior on ``tau``.
        gamma_rate: Rate of the Gamma prior on ``tau``.

    Returns:
        The scalar Normal term added to the Gamma log density, which has
        ``tau``'s shape.
    """
    log_mu = multivariate_normal.logpdf(mu, mean=mu0, cov=beta).sum()
    # The Gamma prior is parameterised by rate, so pass scale = 1 / rate.
    log_tau = gamma.logpdf(tau, a=gamma_shape, scale=1.0 / gamma_rate)
    return log_mu + log_tau
from functools import partial

import jax.numpy as jnp
from jax import jit, value_and_grad
from jax.scipy.stats import gamma
from scipy.optimize import minimize

import svgp.jax.helpers.svgp_spec as sv
from ml_tools.constrain import apply_transformation
from ml_tools.flattening import flatten_and_summarise, reconstruct
from ml_tools.jax import convert_decorator
from ml_tools.jax_kernels import bias_kernel, matern_kernel_32
from svgp.jax.likelihoods import bernoulli_probit_lik
from svgp.jax.quadrature import expectation_1d


def gamma_default_lscale_prior_fn(params):
    """Return the summed Gamma(3, scale=3) log density of ``params["lengthscales"]``."""
    return jnp.sum(gamma.logpdf(params["lengthscales"], 3.0, scale=3.0))


# apply_transformation with search_key="log_", transformation=jnp.exp,
# replace_with="". Presumably it exponentiates "log_"-prefixed entries and
# strips the prefix; confirm in ml_tools.constrain.
constrain_positive = partial(
    apply_transformation, search_key="log_", transformation=jnp.exp, replace_with=""
)


def ard_kernel_currier(params, base_kernel=matern_kernel_32):
    """A kernel getter to be used with get_kernel_fun.

    Given a parameter_dict containing the entries "lengthscales" and "alpha",
    returns the base_kernel ready to evaluate.

    Returns:
        A callable ``(x1, x2, diag_only=False)`` that forwards to
        ``base_kernel(x1, x2, params["lengthscales"], params["alpha"],
        diag_only)``.
    """

    def curried_kernel_fun(x1, x2, diag_only=False):
        return base_kernel(
            x1, x2, params["lengthscales"], params["alpha"], diag_only
        )

    # The curried kernel was previously built but never returned, so callers
    # received None despite the docstring's contract.
    return curried_kernel_fun
def nLL_sep(x, uncens_obs, uncens_gammas, cens_obs, cens_gammas):
    """Weighted negative log likelihood of a Gamma model with right censoring.

    Args:
        x: Pair ``(a, scale)``: the Gamma shape and scale.
        uncens_obs: Fully observed values; each contributes the Gamma log
            density.
        uncens_gammas: Per-observation weights for ``uncens_obs``.
            Presumably EM responsibilities; confirm against callers.
        cens_obs: Right-censored values, known only to exceed the recorded
            value; each contributes the log survival function.
        cens_gammas: Per-observation weights for ``cens_obs``.

    Returns:
        ``-(sum_i w_i * log f(t_i) + sum_j v_j * log S(t_j))``.
    """
    a, scale = x
    uncens = jnp.dot(uncens_gammas, gamma.logpdf(uncens_obs, a=a, scale=scale))
    # gammaincc(a, t / scale) is the Gamma survival function P(T > t). It must
    # be logged to be on the same scale as the log density term above.
    # NOTE(review): gammaincc underflows to 0 (log -> -inf) for very large
    # t / scale; consider a log-space survival function if that occurs.
    cens = jnp.dot(cens_gammas, jnp.log(gammaincc(a, cens_obs / scale)))
    return -1 * (uncens + cens)