コード例 #1
0
 def log_joint(x, w, epsilon, tau, alpha, beta):
   log_p_epsilon = log_probs.norm_gen_log_prob(epsilon, 0, 1)
   log_p_w = log_probs.norm_gen_log_prob(w, 0, 1)
   log_p_tau = log_probs.gamma_gen_log_prob(tau, alpha, beta)
   # TODO(mhoffman): The transposed version below should work.
   # log_p_x = log_probs.norm_gen_log_prob(x, np.dot(epsilon, w), 1. / np.sqrt(tau))
   log_p_x = log_probs.norm_gen_log_prob(x, np.einsum('ik,jk->ij', epsilon, w),
                                         1. / np.sqrt(tau))
   return log_p_epsilon + log_p_w + log_p_tau + log_p_x
    def log_joint(pi, z, mu, tau):
        log_p_pi = log_probs.dirichlet_gen_log_prob(pi, alpha)
        log_p_mu = log_probs.norm_gen_log_prob(mu, 0.,
                                               1. / np.sqrt(kappa * tau))
        log_p_z = log_probs.categorical_gen_log_prob(z, pi)
        log_p_tau = log_probs.gamma_gen_log_prob(tau, a, b)

        z_one_hot = one_hot(z, len(pi))
        mu_z = np.dot(z_one_hot, mu)
        log_p_x = log_probs.norm_gen_log_prob(x, mu_z, 1. / np.sqrt(tau))
        return log_p_pi + log_p_z + log_p_mu + log_p_x
コード例 #3
0
def log_joint(tau, mu, x, alpha, beta, kappa, mu0):
    log_p_tau = log_probs.gamma_gen_log_prob(tau, alpha, beta)
    log_p_mu = log_probs.norm_gen_log_prob(mu, mu0, 1. / np.sqrt(kappa * tau))
    log_p_x = log_probs.norm_gen_log_prob(x, mu, 1. / np.sqrt(tau))
    return log_p_tau + log_p_mu + log_p_x
コード例 #4
0
 def log_joint(x, precision, a, b):
   log_p_precision = log_probs.gamma_gen_log_prob(precision, a, b)
   log_p_x = log_probs.norm_gen_log_prob(x, 0., 1. / np.sqrt(precision))
   return log_p_precision + log_p_x
コード例 #5
0
 def log_joint(x, y, a, b):
   log_prior = log_probs.gamma_gen_log_prob(x, a, b)
   log_likelihood = np.sum(log_probs.gamma_gen_log_prob(y, a, a * x))
   return log_prior + log_likelihood
コード例 #6
0
 def log_joint(x, y, a, b):
   log_prior = log_probs.gamma_gen_log_prob(x, a, b)
   log_likelihood = np.sum(-special.gammaln(y + 1) + y * np.log(x) - x)
   return log_prior + log_likelihood