def main(argv):
  """Samples a normal-gamma setup and prints the derived exact posteriors.

  Ground-truth values are drawn as

    tau ~ Gamma(shape=a, scale=1/b)
    mu  ~ Normal(mu0, 1 / sqrt(tau * kappa))
    x_i ~ Normal(mu, 1 / sqrt(tau)),   i = 1..n_examples

  (presumably mirroring `model`, which is defined elsewhere — confirm).
  `conjugacy` is then used to derive the complete conditional of `mu` and,
  after integrating `mu` out, the posterior of `tau`. Both are printed next
  to the sampled true values.

  Args:
    argv: Command-line arguments; ignored.
  """
  del argv  # Unused.

  n_examples = 10
  a, b = 1.3, 2.2
  kappa, mu0 = 1.5, 0.3

  # Ground truth. The order of these three draws fixes the random stream.
  tau = np.random.gamma(a, 1. / b)
  mu = np.random.normal(mu0, 1. / np.sqrt(tau * kappa))
  x = np.random.normal(mu, 1. / np.sqrt(tau), n_examples)

  # Positional layout of the log-joint's arguments: tau sits at index 4 and
  # mu at index 5.
  tau_argnum, mu_argnum = 4, 5
  args_full = [a, b, kappa, mu0, tau, mu, x]
  args_no_mu = [a, b, kappa, mu0, tau, x]
  args_no_mu_tau = [a, b, kappa, mu0, x]

  log_joint = ph.make_log_joint_fn(model)
  real = conjugacy.SupportTypes.REAL
  nonnegative = conjugacy.SupportTypes.NONNEGATIVE

  # mu given every other argument.
  mu_conditional = conjugacy.complete_conditional(
      log_joint, mu_argnum, real, *args_full)(*args_no_mu)

  # Integrate mu out, then take tau's complete conditional in what remains.
  log_p_tau = conjugacy.marginalize(log_joint, mu_argnum, real, *args_full)
  tau_conditional = conjugacy.complete_conditional(
      log_p_tau, tau_argnum, nonnegative, *args_no_mu)(*args_no_mu_tau)

  print('True tau: {}'.format(tau))
  # The report treats args[0] as the gamma shape and args[2] as its scale
  # (args[1] is presumably a loc parameter); the printed rate is 1 / scale.
  shape = tau_conditional.args[0]
  scale = tau_conditional.args[2]
  print('tau posterior is gamma({}, {}). Mean is {}, std. dev. is {}.'.format(
      shape, 1. / scale, shape * scale, np.sqrt(shape) * scale))
  print()
  print('True mu: {}'.format(mu))
  print('mu posterior given tau is normal({}, {})'.format(
      mu_conditional.args[0], mu_conditional.args[1]))
def _condition_and_marginalize(log_joint, argnum, support, *args):
  """Returns (conditional, marginal value) for argument `argnum` of `log_joint`.

  Both helpers are built from the full argument list `args` and then
  evaluated at every argument except the one at position `argnum`.

  Args:
    log_joint: Log-density function whose positional arguments are `args`.
    argnum: Index of the argument to condition on / integrate out.
    support: Support type of that argument, forwarded to both helpers.
    *args: Values for every positional argument of `log_joint`.

  Returns:
    A tuple `(conditional, marginalized_value)`. `conditional` is what the
    `complete_conditional` factory returns for the remaining arguments.
    `marginalized_value` is the result of calling the function returned by
    `marginalize` on those same remaining arguments.
  """
  # `args` is a tuple here, so this is a tuple without position `argnum`.
  remaining = args[:argnum] + args[argnum + 1:]

  # Call order (marginalize first, then complete_conditional) matches the
  # original implementation.
  marginalized_value = marginalize(
      log_joint, argnum, support, *args)(*remaining)
  conditional = complete_conditional(
      log_joint, argnum, support, *args)(*remaining)
  return conditional, marginalized_value
def make_marginal_fn():
  """Builds a function that accumulates log p(y_1, ..., y_T) step by step.

  The conjugacy helpers are invoked once, up front, to produce:
    * `log_p_y1` and a factory for the conditional of x_1 given y_1, both
      derived from `log_p_x1_y1` by integrating out / conditioning on its
      argument 0;
    * `log_p_ytt` and a factory for the next-state conditional, derived from
      `log_p_xt_xtt_ytt` after integrating out its argument 0 (presumably
      the previous state x_t — confirm against its definition).

  Returns:
    A function `marginal(y_list, x_scale, y_scale)` that returns the summed
    log marginal over the observations in `y_list`. `y_list[0]` is indexed
    unconditionally, so `y_list` must be non-empty.
  """
  real = SupportTypes.REAL

  def placeholders(count):
    # Build-time positional arguments (all 1.) for the conjugacy helpers;
    # the real values are supplied later, when the results are called.
    return [1.] * count

  # First step: from p(x1, y1) derive x1 | y1 and log p(y1).
  x1_given_y1_factory = complete_conditional(
      log_p_x1_y1, 0, real, *placeholders(4))
  log_p_y1 = marginalize(log_p_x1_y1, 0, real, *placeholders(4))

  # Later steps: drop argument 0 from the three-variable density, then derive
  # both the one-step log predictive and the updated state conditional.
  log_p_xtt_ytt = marginalize(log_p_xt_xtt_ytt, 0, real, *placeholders(7))
  log_p_ytt = marginalize(log_p_xtt_ytt, 0, real, *placeholders(6))
  xt_conditional_factory = complete_conditional(
      log_p_xtt_ytt, 0, real, *placeholders(6))

  def marginal(y_list, x_scale, y_scale):
    """Returns the log marginal likelihood of the observations in `y_list`."""
    first_y = y_list[0]
    log_p_y = log_p_y1(first_y, x_scale, y_scale)
    state = x1_given_y1_factory(first_y, x_scale, y_scale)

    for t in range(1, len(y_list)):
      y_t = y_list[t]
      # The previous conditional's args[0] / args[1] (location and scale)
      # feed both the predictive term and the next conditional.
      prev_loc, prev_scale = state.args[0], state.args[1]
      log_p_y += log_p_ytt(y_t, prev_loc, prev_scale, x_scale, y_scale)
      state = xt_conditional_factory(
          y_t, prev_loc, prev_scale, x_scale, y_scale)
    return log_p_y

  return marginal
  def testTwoGaussians(self):
    """Integrating x1 out of x1 ~ N(0, 1), x2 ~ N(x1, 1) leaves x2 ~ N(0, 2).

    The log-joint is unnormalized, so the expected log marginal is
    -x2**2 / 4 + 0.5 * log(pi), written below as
    -0.25 * x2**2 - 0.5 * log(2) + 0.5 * log(2 * pi).
    """
    def log_joint(x1, x2):
      # log N(x1 | 0, 1) up to an additive constant.
      prior_term = -0.5 * x1 * x1
      residual = x2 - x1
      # log N(x2 | x1, 1) up to an additive constant.
      likelihood_term = -0.5 * residual * residual
      return prior_term + likelihood_term

    # Draw order matters: it consumes the global numpy random stream.
    x1 = np.random.randn()
    x2 = x1 + np.random.randn()

    log_p_x2 = marginalize(log_joint, 0, SupportTypes.REAL, x1, x2)
    expected_log_marginal = (
        -0.25 * x2 * x2 - 0.5 * np.log(2.) + 0.5 * np.log(2. * np.pi))
    self.assertAlmostEqual(expected_log_marginal, log_p_x2(x2))

    # With x2 as the only argument, its complete conditional is the N(0, 2)
    # marginal: location 0 and squared scale 2.
    x2_posterior = complete_conditional(
        log_p_x2, 0, SupportTypes.REAL, x2)()
    self.assertAlmostEqual(x2_posterior.args[0], 0.)
    self.assertAlmostEqual(x2_posterior.args[1] ** 2, 2.)