def model(sigma_sq, alpha, sigma_sq_mu):
    """Sample observations from a mixture-of-normals generative process.

    Draw order: mixture weights from a Dirichlet, a [5, 2] array of
    component means from a zero-mean normal, 200 categorical assignments,
    and finally one normal draw per assignment centred on the mean that
    the assignment selects.

    Args:
        sigma_sq: Observation variance; its square root is handed to
            `ph.norm.rvs` for the final draw.
        alpha: Parameter passed to `ph.dirichlet.rvs`.
        sigma_sq_mu: Variance used for the component means; its square
            root is handed to `ph.norm.rvs`.

    Returns:
        The sampled observations.
    """
    mixture_weights = ph.dirichlet.rvs(alpha)

    # NOTE(review): the means always have 5 rows, so the dot product below
    # presumably requires len(mixture_weights) == 5 -- confirm with callers.
    component_means = ph.norm.rvs(0., np.sqrt(sigma_sq_mu), size=[5, 2])

    assignments = ph.categorical.rvs(mixture_weights, size=200)
    # Select each observation's mean by multiplying the one-hot encoded
    # assignments with the table of component means.
    selected_means = np.dot(
        one_hot(assignments, len(mixture_weights)), component_means)

    return ph.norm.rvs(selected_means, np.sqrt(sigma_sq))
def log_joint(p, x, alpha):
    """Return the log joint of a Dirichlet prior and categorical data.

    Args:
        p: Probability vector; `np.log(p)` is used in both the prior and
            the likelihood terms.
        x: Observations, one-hot encoded with `alpha.shape[0]` classes
            before use.
        alpha: Dirichlet concentration array.

    Returns:
        Dirichlet log density of `p` plus the summed categorical
        log-likelihood of `x`.
    """
    # Log of the Dirichlet normalising constant.
    log_normalizer = (
        -special.gammaln(alpha).sum() + special.gammaln(alpha.sum()))
    prior_term = np.sum((alpha - 1) * np.log(p))
    prior_term += log_normalizer

    # TODO(mhoffman): We should make it possible to only use one-hot
    # when necessary.
    encoded_x = one_hot(x, alpha.shape[0])
    likelihood_term = np.sum(np.dot(encoded_x, np.log(p)))

    return prior_term + likelihood_term
def log_joint(x, pi, z, mu, sigma_sq, alpha, sigma_sq_mu):
    """Return the log joint density of a normal mixture model.

    The result is the sum of four terms: the Dirichlet log density of
    `pi`, the categorical log density of `z` given `pi`, the zero-mean
    normal log density of `mu`, and the normal log density of `x` around
    the means selected by `z`.

    Args:
        x: Observations.
        pi: Mixture weights.
        z: Component assignments, one-hot encoded with `len(pi)` classes.
        mu: Component means.
        sigma_sq: Observation variance.
        alpha: Dirichlet parameter for `pi`.
        sigma_sq_mu: Variance of the prior on `mu`.

    Returns:
        The summed log density.
    """
    weights_term = log_probs.dirichlet_gen_log_prob(pi, alpha)
    means_term = log_probs.norm_gen_log_prob(mu, 0, np.sqrt(sigma_sq_mu))
    assignment_term = log_probs.categorical_gen_log_prob(z, pi)

    # Pick out the mean associated with each assignment.
    selected_means = np.dot(one_hot(z, len(pi)), mu)
    data_term = log_probs.norm_gen_log_prob(
        x, selected_means, np.sqrt(sigma_sq))

    return weights_term + assignment_term + means_term + data_term
def log_joint(pi, z, mu, tau):
    """Return the log joint density of a normal mixture with unknown precision.

    `alpha`, `kappa`, `a`, `b` and the observations `x` are not parameters;
    they are read from the enclosing scope.

    The result sums the Dirichlet term for `pi`, the categorical term for
    `z`, the normal term for `mu` (standard deviation
    `1 / sqrt(kappa * tau)`), the Gamma term for `tau`, and the normal
    term for `x` (standard deviation `1 / sqrt(tau)`).

    Args:
        pi: Mixture weights.
        z: Component assignments, one-hot encoded with `len(pi)` classes.
        mu: Component means.
        tau: Precision shared by the prior on `mu` and the likelihood.

    Returns:
        The summed log density, including the prior on `tau`.
    """
    log_p_pi = log_probs.dirichlet_gen_log_prob(pi, alpha)
    log_p_mu = log_probs.norm_gen_log_prob(mu, 0., 1. / np.sqrt(kappa * tau))
    log_p_z = log_probs.categorical_gen_log_prob(z, pi)
    log_p_tau = log_probs.gamma_gen_log_prob(tau, a, b)

    # Pick out the mean associated with each assignment.
    z_one_hot = one_hot(z, len(pi))
    mu_z = np.dot(z_one_hot, mu)
    log_p_x = log_probs.norm_gen_log_prob(x, mu_z, 1. / np.sqrt(tau))

    # The Gamma prior on tau was previously computed but dropped from the
    # sum, so the result was not the full joint over (pi, z, mu, tau, x).
    return log_p_pi + log_p_z + log_p_mu + log_p_x + log_p_tau
def log_joint(x, pi, z, mu, sigma_sq, alpha, sigma_sq_mu):
    """Return the log joint of a normal mixture, using einsum contractions.

    The categorical term for `z` and the per-observation means are both
    computed with `np.einsum` on the one-hot encoding of `z`.

    Args:
        x: Observations.
        pi: Mixture weights.
        z: Component assignments, one-hot encoded with `len(pi)` classes.
        mu: Component means.
        sigma_sq: Observation variance.
        alpha: Dirichlet parameter for `pi`.
        sigma_sq_mu: Variance of the prior on `mu`.

    Returns:
        The summed log density.
    """
    weights_term = log_probs.dirichlet_gen_log_prob(pi, alpha)
    means_term = log_probs.norm_gen_log_prob(mu, 0, np.sqrt(sigma_sq_mu))

    encoded_z = one_hot(z, len(pi))
    # Full contraction: sums log(pi) over the class each row selects.
    assignment_term = np.einsum('ij,j->', encoded_z, np.log(pi))
    # Row-wise selection of the mean belonging to each assignment.
    selected_means = np.einsum('ij,jk->ik', encoded_z, mu)

    data_term = log_probs.norm_gen_log_prob(
        x, selected_means, np.sqrt(sigma_sq))

    return weights_term + assignment_term + means_term + data_term
def structured_categorical_logpdf(x, natparam, dim):
    """Sum the einsum contraction of every factor's parameter with its nodes.

    Each entry of `natparam.single_onehot_xis` and
    `natparam.joint_onehot_xis` maps a factor (a sequence of node keys) to
    a parameter array. For each factor the parameter is fully contracted
    against the one-hot encodings of the factor's nodes, and the scalar
    results are accumulated.

    Args:
        x: Container indexed by node key; `x[node]` is one-hot encoded
            with `dim` classes.
        natparam: Object exposing the `single_onehot_xis` and
            `joint_onehot_xis` mappings.
        dim: Number of classes passed to `one_hot`.

    Returns:
        The accumulated value (0 when both mappings are empty).

    Raises:
        IndexError: If a factor has more than 26 nodes, since index
            letters are taken from the lowercase alphabet.
    """
    alphabet = 'abcdefghijklmnopqrstuvwxyz'
    logp = 0

    # `.items()` works on both Python 2 and Python 3 mappings; the former
    # `.iteritems()` call raises AttributeError on Python 3 dicts.
    factor_iter = chain(natparam.single_onehot_xis.items(),
                        natparam.joint_onehot_xis.items())
    for factor, param in factor_iter:
        # One index letter per node in the factor, e.g. ['a', 'b'].
        node_idxs = [alphabet[i] for i in range(len(factor))]
        # E.g. 'ab,a,b->': the parameter carries all indices and each
        # one-hot operand carries one; everything is summed out.
        formula = '{}->'.format(','.join([''.join(node_idxs)] + node_idxs))
        logp += np.einsum(formula, param,
                          *[one_hot(x[node], dim) for node in factor])
    return logp
def fun(x, y, e):
    """Contract `x` and `y` against the same one-hot encoding of `e`.

    `e` is one-hot encoded with `x.shape[0]` classes, and that encoding is
    used twice in the einsum 'ab,bc,ad,dc->ac': once paired with `x` and
    once paired with `y`.

    Args:
        x: Array whose first dimension sets the number of one-hot classes.
        y: Array contracted against the same encoding as `x`.
        e: Values to one-hot encode.

    Returns:
        The result of the einsum contraction.
    """
    num_classes = x.shape[0]
    selector = tracers.one_hot(e, num_classes)
    return np.einsum('ab,bc,ad,dc->ac', selector, x, selector, y)
def log_joint(x, probs):
    """Return the summed categorical log-likelihood of `x` under `probs`.

    `vocab_size` is read from the enclosing scope and sets the number of
    one-hot classes.

    Args:
        x: Observations to one-hot encode.
        probs: Probabilities whose logarithm is dotted with the encoding.

    Returns:
        The sum of the dot product of the one-hot encoding and
        `np.log(probs)`.
    """
    encoded_x = one_hot(x, vocab_size)
    per_item_log_prob = np.dot(encoded_x, np.log(probs))
    return np.sum(per_item_log_prob)