def main(argv):
  """Demo: conjugate posterior inference in a normal-gamma model."""
  del argv  # Unused by this demo.

  n_examples = 10
  a, b = 1.3, 2.2  # Gamma prior shape/rate for the precision tau.
  kappa = 1.5      # Prior precision scaling for mu.
  mu0 = 0.3        # Prior mean for mu.

  # Draw ground-truth parameters and synthetic observations.
  tau = np.random.gamma(a, 1. / b)
  mu = np.random.normal(mu0, 1. / np.sqrt(tau * kappa))
  x = np.random.normal(mu, 1. / np.sqrt(tau), n_examples)

  all_args = [a, b, kappa, mu0, tau, mu, x]
  all_args_ex_mu = [a, b, kappa, mu0, tau, x]

  log_joint = ph.make_log_joint_fn(model)

  # Complete conditional of mu (argument index 5 of log_joint).
  mu_conditional_factory = conjugacy.complete_conditional(
      log_joint, 5, conjugacy.SupportTypes.REAL, *all_args)
  mu_conditional = mu_conditional_factory(*all_args_ex_mu)

  # Marginalize mu out (index 5), then form the conditional of tau (index 4).
  log_p_tau = conjugacy.marginalize(log_joint, 5, conjugacy.SupportTypes.REAL,
                                    *all_args)
  tau_conditional_factory = conjugacy.complete_conditional(
      log_p_tau, 4, conjugacy.SupportTypes.NONNEGATIVE, *all_args_ex_mu)
  tau_conditional = tau_conditional_factory(a, b, kappa, mu0, x)

  print('True tau: {}'.format(tau))
  # args[0] is the gamma shape and args[2] its scale, so 1/args[2] is the rate.
  print('tau posterior is gamma({}, {}). Mean is {}, std. dev. is {}.'.format(
      tau_conditional.args[0], 1. / tau_conditional.args[2],
      tau_conditional.args[0] * tau_conditional.args[2],
      np.sqrt(tau_conditional.args[0]) * tau_conditional.args[2]))
  print()
  print('True mu: {}'.format(mu))
  print('mu posterior given tau is normal({}, {})'.format(
      mu_conditional.args[0], mu_conditional.args[1]))
  def testFactorAnalysis(self):
    """Tests complete conditionals for a Bayesian factor-analysis model.

    Model: x ~ N(epsilon w^T, std. dev. 1/sqrt(tau)) with standard-normal
    priors on the loadings w and factors epsilon and a gamma(alpha, beta)
    prior on the precision tau.
    """
    def log_joint(x, w, epsilon, tau, alpha, beta):
      log_p_epsilon = log_probs.norm_gen_log_prob(epsilon, 0, 1)
      log_p_w = log_probs.norm_gen_log_prob(w, 0, 1)
      log_p_tau = log_probs.gamma_gen_log_prob(tau, alpha, beta)
      # TODO(mhoffman): The transposed version below should work.
      # log_p_x = log_probs.norm_gen_log_prob(x, np.dot(epsilon, w), 1. / np.sqrt(tau))
      # The einsum computes epsilon.dot(w.T), i.e. the per-observation means.
      log_p_x = log_probs.norm_gen_log_prob(x, np.einsum('ik,jk->ij', epsilon, w),
                                            1. / np.sqrt(tau))
      return log_p_epsilon + log_p_w + log_p_tau + log_p_x

    n_examples = 20
    D = 10  # Observed dimensionality.
    K = 5   # Latent dimensionality.
    alpha = 2.
    beta = 8.
    tau = np.random.gamma(alpha, beta)
    w = np.random.normal(loc=0, scale=1, size=[D, K])
    epsilon = np.random.normal(loc=0, scale=1, size=[n_examples, K])
    x = np.random.normal(loc=epsilon.dot(w.T), scale=np.sqrt(tau))
    all_args = [x, w, epsilon, tau, alpha, beta]

    # Conditional of w (argnum 1): every row shares the same covariance, and
    # the means follow the standard conjugate Gaussian update.
    w_conditional_factory = complete_conditional(log_joint, 1,
                                                 SupportTypes.REAL, *all_args)
    conditional = w_conditional_factory(x, epsilon, tau, alpha, beta)
    true_cov = np.linalg.inv(tau * np.einsum('nk,nl->kl', epsilon, epsilon) +
                             np.eye(K))
    true_mean = tau * np.einsum('nk,nd,kl->dl', epsilon, x, true_cov)
    for d in range(D):
      self.assertTrue(np.allclose(conditional[d].cov, true_cov))
      self.assertTrue(np.allclose(conditional[d].mean, true_mean[d]))

    # Conditional of epsilon (argnum 2), symmetric to the w case with the
    # roles of w and epsilon swapped.
    epsilon_conditional_factory = complete_conditional(log_joint, 2,
                                                       SupportTypes.REAL,
                                                       *all_args)
    conditional = epsilon_conditional_factory(x, w, tau, alpha, beta)
    true_cov = np.linalg.inv(tau * np.einsum('dk,dl->kl', w, w) + np.eye(K))
    true_mean = tau * np.einsum('dk,nd,kl->nl', w, x, true_cov)
    for n in range(n_examples):
      self.assertTrue(np.allclose(conditional[n].cov, true_cov))
      self.assertTrue(np.allclose(conditional[n].mean, true_mean[n]))

    # Conditional of tau (argnum 3): conjugate gamma update with shape
    # alpha + ND/2 and rate beta + sum of squared residuals / 2.
    tau_conditional_factory = complete_conditional(log_joint, 3,
                                                   SupportTypes.NONNEGATIVE,
                                                   *all_args)
    conditional = tau_conditional_factory(x, w, epsilon, alpha, beta)
    true_a = alpha + 0.5 * n_examples * D
    true_b = beta + 0.5 * np.sum(np.square(x - epsilon.dot(w.T)))
    self.assertAlmostEqual(true_a, conditional.args[0])
    # args[2] of the frozen gamma is its scale, so 1/args[2] is the rate.
    self.assertAlmostEqual(true_b, 1. / conditional.args[2])
def _condition_and_marginalize(log_joint, argnum, support, *args):
  """Returns (complete conditional, marginal log prob) for args[argnum]."""
  # All arguments except the one being conditioned on / marginalized out.
  remaining_args = args[:argnum] + args[argnum + 1:]

  marginal_fn = marginalize(log_joint, argnum, support, *args)
  marginalized_value = marginal_fn(*remaining_args)

  conditional = complete_conditional(
      log_joint, argnum, support, *args)(*remaining_args)

  return conditional, marginalized_value
# --- Example 4 ---
def main(argv):
    """Runs Gibbs sampling for a Gaussian mixture model on synthetic data."""
    del argv  # Unused.

    n_clusters = 5
    n_dimensions = 2
    n_observations = 200

    alpha = 3.3 * np.ones(n_clusters)  # Dirichlet concentration.
    sigma_sq_mu = 1.5 ** 2             # Prior variance of the cluster means.
    sigma_sq = 0.5 ** 2                # Observation variance.

    np.random.seed(10001)

    # Sample ground-truth mixture weights, means, assignments, and data.
    pi = np.random.gamma(alpha)
    pi /= pi.sum()
    mu = np.random.normal(0, np.sqrt(sigma_sq_mu), [n_clusters, n_dimensions])
    z = np.random.choice(np.arange(n_clusters), size=n_observations, p=pi)
    x = np.random.normal(mu[z, :], sigma_sq)

    # Initialize the estimates that the sampler updates in place below.
    pi_est = np.ones(n_clusters) / n_clusters
    z_est = np.random.choice(np.arange(n_clusters),
                             size=n_observations,
                             p=pi_est)
    mu_est = np.random.normal(0., 0.01, [n_clusters, n_dimensions])

    all_args = [sigma_sq, alpha, sigma_sq_mu, pi_est, mu_est, z_est, x]

    log_joint = ph.make_log_joint_fn(model)
    # Complete conditionals for pi (index 3), z (index 5), and mu (index 4).
    pi_posterior = conjugacy.complete_conditional(
        log_joint, 3, conjugacy.SupportTypes.SIMPLEX, *all_args)
    z_posterior = conjugacy.complete_conditional(
        log_joint, 5, conjugacy.SupportTypes.INTEGER, *all_args)
    mu_posterior = conjugacy.complete_conditional(
        log_joint, 4, conjugacy.SupportTypes.REAL, *all_args)

    print('iteration\tlog_joint')
    for step in range(100):
        # In-place slice assignment keeps all_args aliased to fresh samples.
        z_est[:] = z_posterior(*remove_arg(5, all_args)).rvs()
        pi_est[:] = pi_posterior(*remove_arg(3, all_args)).rvs()
        mu_est[:] = mu_posterior(*remove_arg(4, all_args)).rvs()
        print('{}\t\t{}'.format(step, log_joint(*all_args)))
# --- Example 5 ---
def run_gibbs(log_joint_fn, all_args, num_iterations):
    """Train model with Gibbs sampling."""
    alpha, beta, epsilon, w, tau, x = all_args

    # Form complete-conditional factories for epsilon (2), w (3), tau (4).
    epsilon_factory = conjugacy.complete_conditional(
        log_joint_fn, 2, conjugacy.SupportTypes.REAL, *all_args)
    w_factory = conjugacy.complete_conditional(
        log_joint_fn, 3, conjugacy.SupportTypes.REAL, *all_args)
    tau_factory = conjugacy.complete_conditional(
        log_joint_fn, 4, conjugacy.SupportTypes.NONNEGATIVE, *all_args)

    # Close over the fixed quantities (alpha, beta, x) so the sampler only
    # has to pass the latent variables around.
    def epsilon_conditional(w, tau):
        return epsilon_factory(alpha, beta, w, tau, x)

    def w_conditional(epsilon, tau):
        return w_factory(alpha, beta, epsilon, tau, x)

    def tau_conditional(epsilon, w):
        return tau_factory(alpha, beta, epsilon, w, x)

    def log_posterior(epsilon, w, tau):
        return log_joint_fn(alpha, beta, epsilon, w, tau, x)

    # Run training loop. Track expected log joint probability, i.e.,
    # E [ log p(xnew, params | xtrain) ]. It is estimated with 1 posterior sample.
    print('Running Gibbs...')
    epsilon = ph.norm.rvs(0, 1, size=epsilon.shape)
    w = ph.norm.rvs(0, 1, size=w.shape)
    tau = ph.gamma.rvs(alpha, scale=1. / beta)
    log_joints = []
    runtimes = []
    start = time.time()
    for t in range(num_iterations):
        epsilon = epsilon_conditional(w, tau).rvs()
        w = w_conditional(epsilon, tau).rvs()
        tau = tau_conditional(epsilon, w).rvs()
        if t % FLAGS.num_print == 0 or (t + 1) == num_iterations:
            log_joint = log_posterior(epsilon, w, tau)
            runtime = time.time() - start
            print('Iteration: {:>3d} Log Joint: {:.3f} '
                  'Runtime (s): {:.3f}'.format(t, log_joint, runtime))
            log_joints.append(log_joint)
            runtimes.append(runtime)
    return log_joints, runtimes
def make_marginal_fn():
  """Builds a filter that computes log p(y_1..T) for a state-space model."""
  # The 1.0s are placeholder argument values for the symbolic transforms.
  x1_posterior_factory = complete_conditional(
      log_p_x1_y1, 0, SupportTypes.REAL, *([1.] * 4))
  log_p_y1 = marginalize(log_p_x1_y1, 0, SupportTypes.REAL, *([1.] * 4))

  log_p_xtt_ytt = marginalize(
      log_p_xt_xtt_ytt, 0, SupportTypes.REAL, *([1.] * 7))
  log_p_ytt = marginalize(log_p_xtt_ytt, 0, SupportTypes.REAL, *([1.] * 6))
  xt_posterior_factory = complete_conditional(
      log_p_xtt_ytt, 0, SupportTypes.REAL, *([1.] * 6))

  def marginal(y_list, x_scale, y_scale):
    """Accumulates log p(y_t | y_1..t-1) over the observation sequence."""
    total = log_p_y1(y_list[0], x_scale, y_scale)
    filtered = x1_posterior_factory(y_list[0], x_scale, y_scale)

    for t in range(1, len(y_list)):
      total += log_p_ytt(y_list[t], filtered.args[0], filtered.args[1],
                         x_scale, y_scale)
      # Update the filtering distribution with the new observation.
      filtered = xt_posterior_factory(y_list[t], filtered.args[0],
                                      filtered.args[1], x_scale, y_scale)
    return total

  return marginal
  def testLinearRegression(self):
    """Tests sufficient statistics, natural parameters, and the complete
    conditional for a Bayesian linear regression model.
    """
    def log_joint(X, beta, y):
      # N(0, I) prior on beta and unit-variance Gaussian errors, written as
      # einsums so the graph analysis can extract sufficient statistics.
      predictions = np.einsum('ij,j->i', X, beta)
      errors = y - predictions
      log_prior = np.einsum('i,i,i->', -0.5 * np.ones_like(beta), beta, beta)
      log_likelihood = np.einsum(',k,k->', -0.5, errors, errors)
      return log_prior + log_likelihood

    n_examples = 10
    n_predictors = 2
    X = np.random.randn(n_examples, n_predictors)
    beta = np.random.randn(n_predictors)
    y = np.random.randn(n_examples)
    graph = make_expr(log_joint, X, beta, y)
    graph = canonicalize(graph)

    # Fix: dict.keys() returns a non-indexable view in Python 3, so
    # materialize a list before indexing. args[1] names the 'beta' variable.
    args = list(graph.free_vars.keys())
    sufficient_statistic_nodes = find_sufficient_statistic_nodes(graph, args[1])
    sufficient_statistics = [eval_node(node, graph.free_vars,
                                       {'X': X, 'beta': beta, 'y': y})
                             for node in sufficient_statistic_nodes]
    correct_sufficient_statistics = [
        -0.5 * beta.dot(beta), beta,
        -0.5 * np.einsum('ij,ik,j,k', X, X, beta, beta)
    ]
    self.assertTrue(_match_values(sufficient_statistics,
                                  correct_sufficient_statistics))

    # The natural parameters multiplying each sufficient statistic.
    _, natural_parameter_funs = _extract_conditional_factors(graph, 'beta')
    self.assertTrue(_match_values(natural_parameter_funs.keys(),
                                  ['x', 'einsum(...a,...b->...ab, x, x)',
                                   'einsum(...,...->..., x, x)'],
                                  lambda x, y: x == y))
    natural_parameter_vals = [f(X, beta, y) for f in
                              natural_parameter_funs.values()]
    correct_parameter_vals = [-0.5 * np.ones(n_predictors), -0.5 * X.T.dot(X),
                              y.dot(X)]
    self.assertTrue(_match_values(natural_parameter_vals,
                                  correct_parameter_vals))

    # The complete conditional on beta is Gaussian with precision X^T X + I.
    conditional_factory = complete_conditional(log_joint, 1, SupportTypes.REAL,
                                               X, beta, y)
    conditional = conditional_factory(X, y)
    true_cov = np.linalg.inv(X.T.dot(X) + np.eye(n_predictors))
    true_mean = true_cov.dot(y.dot(X))
    self.assertTrue(np.allclose(true_cov, conditional.cov))
    self.assertTrue(np.allclose(true_mean, conditional.mean))
  def testLinearRegression(self):
    """Tests sufficient statistics, the statistic representation, and the
    complete conditional for a Bayesian linear regression model.
    """
    def log_joint(X, beta, y):
      # N(0, I) prior on beta and unit-variance Gaussian errors, written as
      # einsums so the graph analysis can extract sufficient statistics.
      predictions = np.einsum('ij,j->i', X, beta)
      errors = y - predictions
      log_prior = np.einsum('i,i,i->', -0.5 * np.ones_like(beta), beta, beta)
      log_likelihood = np.einsum(',k,k->', -0.5, errors, errors)
      return log_prior + log_likelihood

    n_examples = 10
    n_predictors = 2
    X = np.random.randn(n_examples, n_predictors)
    beta = np.random.randn(n_predictors)
    y = np.random.randn(n_examples)
    graph = make_expr(log_joint, X, beta, y)
    graph = canonicalize(graph)

    # Fix: dict.keys() returns a non-indexable view in Python 3, so
    # materialize a list before indexing. args[1] names the 'beta' variable.
    args = list(graph.free_vars.keys())
    sufficient_statistic_nodes = find_sufficient_statistic_nodes(graph, args[1])
    sufficient_statistics = [eval_node(node, graph.free_vars,
                                       {'X': X, 'beta': beta, 'y': y})
                             for node in sufficient_statistic_nodes]
    correct_sufficient_statistics = [
        -0.5 * beta.dot(beta), beta,
        -0.5 * np.einsum('ij,ik,j,k', X, X, beta, beta)
    ]
    self.assertTrue(_match_values(sufficient_statistics,
                                  correct_sufficient_statistics))

    # Differentiating the reparameterized log joint w.r.t. the statistics of
    # beta yields the natural parameters of beta's conditional.
    new_log_joint, _, stats_funs, _ = (
        statistic_representation(log_joint, (X, beta, y),
                                 (SupportTypes.REAL,), (1,)))
    beta_stat_fun = stats_funs[0]
    beta_natparam = grad_namedtuple(new_log_joint, 1)(X, beta_stat_fun(beta), y)
    correct_beta_natparam = (-0.5 * X.T.dot(X), y.dot(X),
                             -0.5 * np.ones(n_predictors))
    self.assertTrue(_match_values(beta_natparam, correct_beta_natparam))

    # The complete conditional on beta is Gaussian with precision X^T X + I.
    conditional_factory = complete_conditional(log_joint, 1, SupportTypes.REAL,
                                               X, beta, y)
    conditional = conditional_factory(X, y)
    true_cov = np.linalg.inv(X.T.dot(X) + np.eye(n_predictors))
    true_mean = true_cov.dot(y.dot(X))
    self.assertTrue(np.allclose(true_cov, conditional.cov))
    self.assertTrue(np.allclose(true_mean, conditional.mean))
  def testTwoGaussians(self):
    """Marginalizes one variable out of a two-variable Gaussian chain."""
    # x1 ~ N(0, 1); x2 | x1 ~ N(x1, 1), all up to additive constants.
    def log_joint(x1, x2):
      log_p_x1 = -0.5 * x1 * x1
      x_diff = x2 - x1
      log_p_x2 = -0.5 * x_diff * x_diff
      return log_p_x1 + log_p_x2

    x1 = np.random.randn()
    x2 = x1 + np.random.randn()
    all_args = [x1, x2]

    # Marginalizing out x1 (argnum 0) should leave log p(x2).
    marginal_p_x2 = marginalize(log_joint, 0, SupportTypes.REAL, *all_args)
    expected_marginal = (
        -0.25 * x2 * x2 - 0.5 * np.log(2.) + 0.5 * np.log(2. * np.pi))
    self.assertAlmostEqual(expected_marginal, marginal_p_x2(x2))

    # The marginal of x2 is N(0, 2): location 0, scale sqrt(2).
    x2_conditional = complete_conditional(marginal_p_x2, 0, SupportTypes.REAL,
                                          x2)()
    self.assertAlmostEqual(x2_conditional.args[0], 0.)
    self.assertAlmostEqual(x2_conditional.args[1] ** 2, 2.)
  def testMixtureOfGaussians(self):
    """Tests complete conditionals for a Gaussian mixture model.

    Checks the Dirichlet posterior over the weights pi, the categorical
    posterior over the assignments z, and the Gaussian posterior over the
    cluster means mu against closed-form conjugate updates.
    """
    def log_joint(x, pi, z, mu, sigma_sq, alpha, sigma_sq_mu):
      log_p_pi = log_probs.dirichlet_gen_log_prob(pi, alpha)
      log_p_mu = log_probs.norm_gen_log_prob(mu, 0, np.sqrt(sigma_sq_mu))

      # Categorical log-likelihood of the assignments z given weights pi.
      z_one_hot = one_hot(z, len(pi))
      log_p_z = np.einsum('ij,j->', z_one_hot, np.log(pi))

      # Select each observation's cluster mean via its one-hot assignment.
      mu_z = np.einsum('ij,jk->ik', z_one_hot, mu)
      log_p_x = log_probs.norm_gen_log_prob(x, mu_z, np.sqrt(sigma_sq))

      return log_p_pi + log_p_z + log_p_mu + log_p_x

    n_clusters = 5
    n_dimensions = 2
    n_observations = 200

    alpha = 3.3 * np.ones(n_clusters)  # Dirichlet concentration.
    sigma_sq_mu = 1.5 ** 2             # Prior variance of the cluster means.
    sigma_sq = 0.5 ** 2                # Observation variance.

    np.random.seed(10001)

    # Generate synthetic weights, means, assignments, and observations.
    pi = np.random.gamma(alpha)
    pi /= pi.sum()
    mu = np.random.normal(0, np.sqrt(sigma_sq_mu), [n_clusters, n_dimensions])
    z = np.random.choice(np.arange(n_clusters), size=n_observations, p=pi)
    # NOTE(review): scale here is sigma_sq (a variance), not sqrt(sigma_sq).
    # The posterior checks below do not depend on how x was generated, so
    # this only affects how realistic the synthetic data is — confirm intent.
    x = np.random.normal(mu[z, :], sigma_sq)

    # Initial estimates to condition on.
    pi_est = np.ones(n_clusters) / n_clusters
    z_est = np.random.choice(np.arange(n_clusters), size=n_observations,
                             p=pi_est)
    mu_est = np.random.normal(0., 0.01, [n_clusters, n_dimensions])

    all_args = [x, pi_est, z_est, mu_est, sigma_sq, alpha, sigma_sq_mu]
    # Argument lists with the conditioned-on variable removed.
    pi_posterior_args = all_args[:1] + all_args[2:]
    z_posterior_args = all_args[:2] + all_args[3:]
    mu_posterior_args = all_args[:3] + all_args[4:]

    pi_posterior = complete_conditional(log_joint, 1, SupportTypes.SIMPLEX,
                                        *all_args)
    z_posterior = complete_conditional(log_joint, 2, SupportTypes.INTEGER,
                                       *all_args)
    mu_posterior = complete_conditional(log_joint, 3, SupportTypes.REAL,
                                        *all_args)

    # Dirichlet posterior: alpha plus per-cluster assignment counts.
    self.assertTrue(np.allclose(
        pi_posterior(*pi_posterior_args).alpha,
        alpha + np.histogram(z_est, np.arange(n_clusters+1))[0]))

    # Categorical posterior: softmax of per-cluster Gaussian log-likelihoods
    # plus the log mixing weights.
    correct_z_logits = -0.5 / sigma_sq * np.square(x[:, :, None] -
                                                   mu_est.T[None, :, :]).sum(1)
    correct_z_logits += np.log(pi_est)
    correct_z_posterior = np.exp(correct_z_logits -
                                 misc.logsumexp(correct_z_logits, 1,
                                                keepdims=True))
    self.assertTrue(np.allclose(correct_z_posterior,
                                z_posterior(*z_posterior_args).p))

    # Gaussian posterior per cluster: conjugate update with a zero prior
    # mean, so the posterior mean is (sum of assigned x) / sigma_sq * var.
    correct_mu_posterior_mean = np.zeros_like(mu_est)
    correct_mu_posterior_var = np.zeros_like(mu_est)
    for k in range(n_clusters):
      n_k = (z_est == k).sum()
      correct_mu_posterior_var[k] = 1. / (1. / sigma_sq_mu + n_k / sigma_sq)
      correct_mu_posterior_mean[k] = (
          x[z_est == k].sum(0) / sigma_sq * correct_mu_posterior_var[k])
    mu_posterior_val = mu_posterior(*mu_posterior_args)
    self.assertTrue(np.allclose(correct_mu_posterior_mean,
                                mu_posterior_val.args[0]))
    # args[1] is the scale (std. dev.), so square it to compare variances.
    self.assertTrue(np.allclose(correct_mu_posterior_var,
                                mu_posterior_val.args[1] ** 2))