Example #1
def metropolis_hastings(theta, theta_star, phi, phi_star, M, key,
                        log_posterior):
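    # Metropolis accept/reject step for a proposal (theta_star, phi_star),
    # where the momenta phi are drawn from N(0, M) with diagonal mass (variance) M.
    # The proposal is accepted with probability min(1, exp(log_accept_prob)).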

    log_prob_phi_star = jnp.sum(
        norm.logpdf(phi_star, jnp.zeros_like(theta), jnp.sqrt(M)))
    log_prob_phi = jnp.sum(norm.logpdf(phi, jnp.zeros_like(theta),
                                       jnp.sqrt(M)))

    log_prob_theta = log_posterior(theta)
    log_prob_theta_star = log_posterior(theta_star)

    log_numerator = log_prob_theta_star + log_prob_phi_star
    log_denominator = log_prob_theta + log_prob_phi

    log_accept_prob = log_numerator - log_denominator

    accept_prob = jnp.exp(log_accept_prob)

    key, subkey = split(key)

    rand_draw = uniform(subkey)

    new_theta = cond(rand_draw < accept_prob, lambda _: theta_star,
                     lambda _: theta, theta)

    return key, new_theta, accept_prob
Example #2
File: models.py Project: sagar87/jaxvi
    def log_joint(self, theta: jnp.DeviceArray) -> jnp.DeviceArray:
        betas = theta[:2]
        sigma = theta[2]

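        # Normal(0, 10) priors on the regression weights and a Gamma(1, scale=2)
        # prior on the noise scale sigma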
        beta_prior = norm.logpdf(betas, 0, 10).sum()
        sigma_prior = gamma.logpdf(sigma, a=1, scale=2).sum()
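        # Gaussian likelihood: y ~ N(x @ betas, sigma)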
        yhat = jnp.inner(self.x, betas)
        likelihood = norm.logpdf(self.y, yhat, sigma).sum()

        return beta_prior + sigma_prior + likelihood
Example #3
File: mcmc.py Project: silky/ml_tools
def run_hmc_steps(theta, eps, Lmax, key, log_posterior,
                  log_posterior_grad_theta, diagonal_mass_matrix):
    # Diagonal mass matrix: diagonal entries of M (a vector)

    inverse_diag_mass = 1. / diagonal_mass_matrix

    key, subkey = random.split(key)

    # Location-scale transform to get the right variance
    # TODO: Check!
    phi = random.normal(
        subkey, shape=(theta.shape[0], )) * np.sqrt(diagonal_mass_matrix)

    start_theta = theta
    start_phi = phi

    cur_grad = log_posterior_grad_theta(theta)

    key, subkey = random.split(key)

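    # Draw the number of leapfrog steps uniformly from {1, ..., Lmax - 1}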
    L = np_classic.random.randint(1, Lmax)

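    # Leapfrog integration: half-step momentum, full-step position, half-step momentum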
    for cur_l in range(L):
        phi = phi + 0.5 * eps * cur_grad
        theta = theta + eps * inverse_diag_mass * phi
        cur_grad = log_posterior_grad_theta(theta)
        phi = phi + 0.5 * eps * cur_grad

    # Compute (log) acceptance probability
    proposed_log_post = log_posterior(theta)
    previous_log_post = log_posterior(start_theta)

    proposed_log_phi = np.sum(
        norm.logpdf(phi, scale=np.sqrt(diagonal_mass_matrix)))
    previous_log_phi = np.sum(
        norm.logpdf(start_phi, scale=np.sqrt(diagonal_mass_matrix)))

    print(f'Proposed log posterior is: {proposed_log_post}. '
          f'Previous was {previous_log_post}.')

    if not np.isfinite(proposed_log_post):
        # Reject
        was_accepted = False
        new_theta = start_theta
        # FIXME: What number to put here?
        log_r = -10
        return was_accepted, log_r, new_theta

    log_r = (proposed_log_post + proposed_log_phi - previous_log_post -
             previous_log_phi)

    was_accepted, new_theta = acceptance_step(log_r, theta, start_theta, key)

    return was_accepted, log_r, new_theta
def calculate_likelihood(x, mu, a, theta, y):

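    # y packs the observed margin, a retirement indicator and a best-of-five flag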
    margin, was_retirement, bo5 = y

    sigma_obs = (1 - bo5) * theta["sigma_obs"] + bo5 * theta["sigma_obs_bo5"]

    # If it wasn't a retirement:
    margin_prob = norm.logpdf(margin, theta["a1"] * (a @ x) + theta["a2"],
                              sigma_obs)

    win_prob = jnp.log(expit((1 + theta["bo5_factor"] * bo5) * b * a @ x))

    # Otherwise:
    # Only the loser's skill matters:
    n_skills = x.shape[0]

    loser_x = x[n_skills // 2:]
    loser_mu = mu[n_skills // 2:]
    loser_a = -a[n_skills // 2:]

    loser_expected_skill = loser_a @ loser_mu
    loser_actual_skill = loser_a @ loser_x

    full_ret_factor = theta["ret_factor"] * (
        1 - theta["skill_factor"] * expit(b * loser_expected_skill))

    ret_prob = expit(full_ret_factor *
                     (loser_expected_skill - loser_actual_skill) +
                     theta["ret_intercept"])

    prob_retirement = jnp.log(ret_prob)
    prob_not_retirement = jnp.log(1 - ret_prob)

    return (1 - was_retirement) * (margin_prob + win_prob + prob_not_retirement
                                   ) + was_retirement * prob_retirement
Example #5
 def loss_fun(params,
              rng,
              data,
              batch_size=None,
              n=None,
              loss_type="nlp",
              reduce="sum"):
     """
     :param batch_size: How large a batch to subselect from the provided data
     :param n: The total size of the dataset (to multiply batch estimate by)
     """
     assert loss_type in ("nlp", "mse")
     inputs, targets = data
     n = inputs.shape[0] if n is None else n
     if batch_size is not None:
         rng, rng_batch = random.split(rng)
         i = random.permutation(rng_batch, n)[:batch_size]
         inputs, targets = inputs[i], targets[i]
     preds = apply_fun(params, rng, inputs).squeeze()
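     # "nlp": negative Gaussian log-density with learned noise scale; "mse": mean squared error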
     mean_loss = (
         -norm.logpdf(targets.squeeze(), preds, params["noise"]).mean()
         if loss_type == "nlp" else np.power(targets.squeeze() -
                                             preds, 2).mean())
     if reduce == "sum":
         loss = n * mean_loss
     elif reduce == "mean":
         loss = mean_loss
     return loss
Example #6
 def log_prob(self, x):
     log_prob = multivariate_normal.logpdf(x,
                                           loc=self._mean,
                                           scale=self._scale)
     # Sum over parameter axes.
     sum_axis = [-(i + 1) for i in range(len(self._param_shape))]
     return jnp.sum(log_prob, axis=sum_axis)
def log_prob_component(i, theta):
    radii, scales = construct_array(i, theta)

    #@jit
    def component(distance_matrix):
        return compute_OBC_energy_vectorized(distance_matrix, radii, scales,
                                             charges[i])

    # TODO: vmap
    W_F = np.array(list(map(component, distance_matrices[i])))
    #W_F = vmap(component)(distance_matrices[i])
    w_F = W_F * kj_mol_to_kT
    pred_free_energy = one_sided_exp(w_F)
    if gaussian_ll:
        return norm.logpdf(pred_free_energy,
                           loc=expt_means[i],
                           scale=expt_uncs[i]**2)
    else:
        # TODO : fix
        # https://github.com/scipy/scipy/blob/c3fa90dcfcaef71658744c73578e9e7d915c81e9/scipy/stats/_continuous_distns.py#L5207
        # def _logpdf(self, x, df):
        #         r = df*1.0
        #         lPx = sc.gammaln((r+1)/2)-sc.gammaln(r/2)
        #         lPx -= 0.5*np.log(r*np.pi) + (r+1)/2*np.log(1+(x**2)/r)
        #         return lPx
        raise NotImplementedError

        return student_t.logpdf(pred_free_energy,
                                loc=expt_means[i],
                                scale=expt_uncs[i]**2,
                                df=7)
Example #8
    def to_minimize(flat_theta):

        theta = reconstruct(flat_theta, summary, jnp.reshape)
        theta = apply_transformation(theta, "log_", jnp.exp, "")

        lik = calculate_likelihood(
            theta,
            species_ids,
            fg_covs,
            bg_covs,
            fg_covs_thin,
            bg_covs_thin,
            quad_weights,
            counts,
            n_s,
            n_fg,
        )
        kl = calculate_kl(theta)

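        # Hyperpriors: gamma on w_prior_var, Gaussian on w_prior_mean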
        prior = jnp.sum(
            gamma.logpdf(theta["w_prior_var"], 0.5, scale=1.0 / n_c))
        prior = prior + jnp.sum(
            norm.logpdf(theta["w_prior_mean"], 0.0, scale=jnp.sqrt(1.0 / n_c)))

        return -(lik - kl + prior)
Example #9
    def _sis_step(self, key, log_weights_prev, mu_prev, xparticles_prev, yobs):
        """
        Compute one step of the sequential-importance-sampling algorithm
        at time t.
        
        Parameters
        ----------
        key: jax.random.PRNGKey
            key to sample particle.
        mu_prev: array(n_particles)
            Term carrying past cumulate values.
        xsamp_prev: array(n_particles)
            Samples / particles from the latent space at t-1
        yobs: float
            Observation at time t.
        """

        key_particles = jax.random.split(key, len(xparticles_prev))

        # 1. Sample from proposal
        xparticles = jax.vmap(self.sample_latent_step)(key_particles,
                                                       xparticles_prev)
        # 2. Evaluate unnormalised weights
        # 2.1 Compute new mean
        mu = self.beta * mu_prev + xparticles
        # 2.2 Compute log-unnormalised weights

        log_weights = log_weights_prev + norm.logpdf(
            yobs, loc=mu, scale=jnp.sqrt(self.q))

        return (log_weights, mu, xparticles), log_weights
Example #10
def diag_gaussian_logpdf(x, mean=None, logvar=None):
    # Log of PDF for diagonal Gaussian for a single point
    D = x.shape[0]
    if (mean is None) and (logvar is None):
        # Standard Gaussian
        mean, logvar = jnp.zeros_like(x), jnp.zeros_like(x)
    # manual = -0.5 * (jnp.log(2*jnp.pi) + logvar + (x-mean)**2 * jnp.exp(-logvar))
    # print('manual check', jnp.sum(manual))
    return jnp.sum(norm.logpdf(x, loc=mean, scale=jnp.exp(0.5 * logvar)))
Example #11
def calculate_likelihood(x, mu, a, theta, y):

    margin = y[0]

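    # Gaussian likelihood for the observed margin plus a log-sigmoid term for the observed win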
    margin_prob = norm.logpdf(margin, theta["a1"] * (a @ x) + theta["a2"],
                              theta["sigma_obs"])

    win_prob = jnp.log(expit(b * a @ x))

    return margin_prob + win_prob
Example #12
 def _blended_loss_landscape(self, positions):
     '''Assumes two-dimensional space'''
     prob_target = self.classifier.predict(positions)
     target = numpy.ones_like(prob_target) * self._target_class
     z = self.vae.encode(positions)[:,:self.vae.n_latent_dims]
     reconstructed_x = self.vae.decode(z)
     log_likelihood_data = (
         norm.logpdf(positions, reconstructed_x, self.vae.x_var ** 0.5)).sum(axis=1)
     likelihood_data_term = LOSS_BLEND_PARAM * (log_likelihood_data + np.log(prob_target.squeeze()))
     return _binary_crossentropy(target, prob_target).squeeze() - likelihood_data_term
 def loss_fun(params, rng, data, cache=None):
     x, y = data
     y = y.squeeze()
     n_batch = x.shape[0]
     loc, scale = gaussian_fun(params,
                               rng,
                               x,
                               noise=True,
                               diag=True,
                               cache=cache)
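     # Negative ELBO estimate: scale the minibatch log-likelihood up to n points and subtract the KL term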
     return -(n / n_batch * norm.logpdf(y, loc, scale).sum() - kl(params))
def diag_gaussian_log_density(x, mu, log_sigma):
    """
  Args:
    x: random variable
    mu: mean
    log_sigma: log standard deviation
  Return:
    log normal density.
  """
    assert x.ndim == 1
    return np.sum(norm.logpdf(x, mu, np.exp(log_sigma)), axis=-1)
def calculate_marginal_lik(x, mu, a, cov_mat, theta, y):

    margin, was_retirement, bo5 = y

    sigma_obs = (1 - bo5) * theta["sigma_obs"] + bo5 * theta["sigma_obs_bo5"]

    latent_mean, latent_var = weighted_sum(x, cov_mat, a)

    # If it wasn't a retirement:
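    # Marginalising over the latent skills inflates the margin variance by a1**2 * latent_var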
    margin_prob = norm.logpdf(
        margin,
        theta["a1"] * (latent_mean) + theta["a2"],
        jnp.sqrt(sigma_obs**2 + theta["a1"]**2 * latent_var),
    )

    win_prob = jnp.log(
        logistic_normal_integral_approx(
            (1 + theta["bo5_factor"] * bo5) * b * latent_mean,
            (1 + theta["bo5_factor"] * bo5)**2 * b**2 * latent_var,
        ))

    n_skills = x.shape[0]

    # Otherwise:
    # Only the loser's skill matters:
    loser_x = x[n_skills // 2:]
    loser_mu = mu[n_skills // 2:]
    loser_a = -a[n_skills // 2:]
    loser_cov_mat = cov_mat[n_skills // 2:, n_skills // 2:]

    loser_actual_mean, loser_actual_var = weighted_sum(loser_x, loser_cov_mat,
                                                       loser_a)
    loser_expected_skill = loser_a @ loser_mu

    full_ret_factor = theta["ret_factor"] * (
        1 - theta["skill_factor"] * expit(b * loser_expected_skill))

    ret_prob = logistic_normal_integral_approx(
        full_ret_factor * (loser_expected_skill - loser_actual_mean) +
        theta["ret_intercept"],
        full_ret_factor**2 * loser_actual_var,
    )

    prob_retirement = jnp.log(ret_prob)
    prob_not_retirement = jnp.log(1 - ret_prob)

    return (1 - was_retirement) * (win_prob + margin_prob + prob_not_retirement
                                   ) + (was_retirement * prob_retirement)
Example #16
def calculate_marginal_lik(x, mu, a, cov_mat, theta, y):

    margin = y[0]

    latent_mean, latent_var = weighted_sum(x, cov_mat, a)

    margin_prob = norm.logpdf(
        margin,
        theta["a1"] * (latent_mean) + theta["a2"],
        jnp.sqrt(theta["sigma_obs"]**2 + theta["a1"]**2 * latent_var),
    )

    win_prob = jnp.log(
        logistic_normal_integral_approx(b * latent_mean, b**2 * latent_var))

    return win_prob + margin_prob
Example #17
    def _smc_step(self, key, log_weights_prev, mu_prev, xparticles_prev, yobs):
        n_particles = len(xparticles_prev)
        key, key_particles = jax.random.split(key)
        key_particles = jax.random.split(key_particles, n_particles)

        # 1. Resample particles
        weights = self._obtain_weights(log_weights_prev)
        ix_sampled = jax.random.choice(key,
                                       n_particles,
                                       p=weights,
                                       shape=(n_particles, ))
        xparticles_prev_sampled = xparticles_prev[ix_sampled]
        mu_prev_sampled = mu_prev[ix_sampled]
        # 2. Propagate particles
        xparticles = jax.vmap(self.sample_latent_step)(key_particles,
                                                       xparticles_prev_sampled)
        # 3. Concatenate
        mu = self.beta * mu_prev_sampled + xparticles

        # ToDo: return dictionary of log_weights and sampled indices
        log_weights = norm.logpdf(yobs, loc=mu, scale=jnp.sqrt(self.q))
        dict_carry = {"log_weights": log_weights, "indices": ix_sampled}
        return (log_weights, mu, xparticles_prev_sampled), dict_carry
Example #18
def _blended_revise_objective(latent_pos, chosen_point,
                              vae, classifier,
                              calc_loss,
                              dist_weight, calc_dist, target_class):
    reconstructed_x = vae.decode(latent_pos)
    
    distance = calc_dist(chosen_point, reconstructed_x)
    distance_term = dist_weight * distance
    
    reconstructed_prob_target = classifier.predict(reconstructed_x)
    target_loss = calc_loss(np.array([1]), reconstructed_prob_target)
    
    reconstructed_z = vae.encode(reconstructed_x)[:vae.n_latent_dims]
    rereconstructed_x = vae.decode(reconstructed_z)
    log_likelihood_data = (norm.logpdf(reconstructed_x, rereconstructed_x, vae.x_var ** 0.5)).sum()
    likelihood_data_term = LOSS_BLEND_PARAM * (log_likelihood_data + np.log(reconstructed_prob_target.squeeze()))
    
    objective = target_loss + distance_term - likelihood_data_term
    return objective.squeeze(), {
        'reconstructed_x': reconstructed_x,
        'reconstructed_prob_target': reconstructed_prob_target,
        'log_likelihood_data': log_likelihood_data
    }
Example #19
def logredshiftprior(x, a, b):
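    # Normal prior on redshift truncated to [0, zmax]; lognorm is the log of the truncation normaliser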
    lognorm = np.log(norm.cdf(zmax, loc=a, scale=b) - norm.cdf(0, loc=a, scale=b))
    return norm.logpdf(x, loc=a, scale=b) - lognorm
Example #20
 def _mv_log_pdf(y, x, s):
     z = jnp.dot(self.W, x)
     return jnp.sum(logpdf(y, z, s))
Example #21
def vt_func(eta, y, mu, v, Zt):
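    # Integrand for the second moment of eta under the tilted density:
    # Bernoulli-logit likelihood times a Gaussian with mean mu and scale v, divided by Zt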
    log_term = y * log_sigmoid(eta) + (1 - y) * jnp.log1p(-sigmoid(eta))
    log_term = log_term + norm.logpdf(eta, mu, v)

    return eta**2 * jnp.exp(log_term) / Zt
Example #22
def Zt_func(eta, y, mu, v):
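    # Integrand for the normalising constant Zt: Bernoulli-logit likelihood times a Gaussian with mean mu and scale v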
    log_term = y * log_sigmoid(eta) + (1 - y) * jnp.log1p(-sigmoid(eta))
    log_term = log_term + norm.logpdf(eta, mu, v)

    return jnp.exp(log_term)
Example #23
def gaussian_lik(y, f, f_sd):

    return norm.logpdf(y, f, f_sd)
Example #24
 def log_prob(self, inputs):
     return norm.logpdf(inputs).sum(1)
Example #25
def f_jvp(primals, tangents):
    z, = primals
    z_dot, = tangents
    primal_out = norm_logcdf(z)
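    # d/dz log Phi(z) = phi(z) / Phi(z), evaluated in log space for numerical stability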
    tangent_out = jnp.exp(norm.logpdf(z) - primal_out) * z_dot
    return primal_out, tangent_out
Example #26
def log_pdf_normal(s):
    """ Log-pdf for a Gaussian distribution w. mean 0 std 1"""
    return jnp.sum(norm.logpdf(s))
Example #27
 def _likelihood_landscape(self, positions):
     z = self.vae.encode(positions)
     new_x = self.vae.decode(z[:,:2])
     log_prob = numpy.sum(norm.logpdf(positions, new_x, self.vae.x_var ** 0.5), axis=1)
     return log_prob
Example #28
 def log_prob(self, value):
     assert_array(value, shape=(..., ) + self.batch_shape)
     return norm.logpdf(value, loc=self._loc, scale=self._scale)
Example #29
File: advi.py Project: zhangfeilong/jax
def funnel_log_density(params):
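    # Neal's funnel: the scale of the first coordinate is exp of the second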
    return norm.logpdf(params[0], 0, np.exp(params[1])) + \
           norm.logpdf(params[1], 0, 1.35)
def log_prior(theta):
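    # Independent standard-normal prior centred at prior_location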
    return np.sum(norm.logpdf(theta - prior_location))