  def testSqrtToPow(self):
    def fun(x):
      return np.sqrt(x)
    expr = make_expr(fun, 3.)
    expr = canonicalize(expr)
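    # Canonicalization is expected to rewrite np.sqrt(x) into an equivalent
    # power expression (roughly np.power(x, 0.5)) without changing its value.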
    self.assertIsInstance(expr, GraphExpr)
    self.assertEqual(expr.expr_node.fun, np.power)
    self.assertEqual(eval_expr(expr, {'x': 3.}), fun(3.))
  def testFindSufficientStatisticNodes(self):

    def log_joint(x, y, matrix):
      # Linear in x: y^T x
      result = np.einsum('i,i->', x, y)
      # Quadratic form: x^T matrix x
      result += np.einsum('ij,i,j->', matrix, x, x)
      # Rank-1 quadratic form: (x**2)^T(y**2)
      result += np.einsum('i,i,j,j->', x, y, x, y)
      # Linear in log(x): y^T log(x)
      result += np.einsum('i,i->', y, np.log(x))
      # Linear in reciprocal(x): y^T reciprocal(x)
      result += np.einsum('i,i->', y, np.reciprocal(x))
      # More obscurely linear in log(x): y^T matrix log(x)
      result += np.einsum('i,ij,j->', y, matrix, np.log(x))
      # Linear in x * log(x): y^T (x * log(x))
      result += np.einsum('i,i->', y, x * np.log(x))
      return result

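    # Modulo coefficients built from y and matrix, the log joint depends on x
    # only through the statistics checked below: x itself, quadratic forms in
    # x, log(x), 1/x, and x * log(x).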
    n_dimensions = 5
    x = np.exp(np.random.randn(n_dimensions))
    y = np.random.randn(n_dimensions)
    matrix = np.random.randn(n_dimensions, n_dimensions)
    env = {'x': x, 'y': y, 'matrix': matrix}

    expr = make_expr(log_joint, x, y, matrix)
    expr = canonicalize(expr)
    sufficient_statistic_nodes = find_sufficient_statistic_nodes(expr, 'x')
    suff_stats = [eval_expr(GraphExpr(node, expr.free_vars), env)
                  for node in sufficient_statistic_nodes]
    correct_suff_stats = [x, x.dot(matrix.dot(x)), np.square(x.dot(y)),
                          np.log(x), np.reciprocal(x), y.dot(x * np.log(x))]
    self.assertTrue(_match_values(suff_stats, correct_suff_stats))

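    # With split_einsums=True, einsums coupling x to other factors should be
    # split apart so each recovered statistic is a pure function of x, e.g.
    # the outer product np.outer(x, x) instead of the full quadratic form.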
    expr = make_expr(log_joint, x, y, matrix)
    expr = canonicalize(expr)
    sufficient_statistic_nodes = find_sufficient_statistic_nodes(
        expr, 'x', split_einsums=True)
    suff_stats = [eval_expr(GraphExpr(node, expr.free_vars), env)
                  for node in sufficient_statistic_nodes]
    correct_suff_stats = [x, np.outer(x, x), x * x,
                          np.log(x), np.reciprocal(x), x * np.log(x)]
    self.assertTrue(_match_values(suff_stats, correct_suff_stats))
  def testEinsumAddSubSimplify(self):
    # TODO(mhoffman): Think about broadcasting. We need to support `x - 2.0`.

    def test_fun(x):
      return np.einsum('i->', x + np.full(x.shape, 2.0))

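    # Canonicalization should distribute the einsum over the addition, giving
    # roughly add_n(einsum('i->', x), <constant>) instead of a single einsum
    # applied to a sum, which is what the assertions below check.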
    expr = make_expr(test_fun, np.ones(3))
    test_x = np.full(3, 0.5)
    correct_value = eval_expr(expr, {'x': test_x})
    expr = canonicalize(expr)
    self.assertIsInstance(expr, GraphExpr)
    self.assertEqual(expr.expr_node.fun, add_n)
    self.assertEqual(expr.expr_node.parents[0].fun.__name__, 'einsum')
    new_value = eval_expr(expr, {'x': test_x})
    self.assertEqual(correct_value, new_value)
  def testCanonicalize(self):

    def mahalanobis_distance(x, y, matrix):
      x_minus_y = x - y
      return np.einsum('i,j,ij->', x_minus_y, x_minus_y, matrix)

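    # Canonicalization should expand the quadratic form (x - y)^T M (x - y)
    # into a sum of einsums, each polynomial in x and y separately.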
    x = np.array([1.3, 3.6])
    y = np.array([2.3, -1.2])
    matrix = np.arange(4).reshape([2, 2])
    expr = make_expr(mahalanobis_distance, x, y, matrix)
    self.assertFalse(is_canonical(expr))
    correct_value = eval_expr(expr, {'x': x, 'y': y, 'matrix': matrix})
    expr = canonicalize(expr)
    self.assertTrue(is_canonical(expr))
    new_value = eval_expr(expr, {'x': x, 'y': y, 'matrix': matrix})
    self.assertAlmostEqual(correct_value, new_value)
  def testLinearRegression(self):
    def log_joint(X, beta, y):
      predictions = np.einsum('ij,j->i', X, beta)
      errors = y - predictions
      log_prior = np.einsum('i,i,i->', -0.5 * np.ones_like(beta), beta, beta)
      log_likelihood = np.einsum(',k,k->', -0.5, errors, errors)
      return log_prior + log_likelihood
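    # Up to additive constants, this is the log joint of a standard-normal
    # prior on beta with a unit-variance Gaussian likelihood y ~ N(X beta, I).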
    n_examples = 10
    n_predictors = 2
    X = np.random.randn(n_examples, n_predictors)
    beta = np.random.randn(n_predictors)
    y = np.random.randn(n_examples)
    graph = make_expr(log_joint, X, beta, y)
    graph = canonicalize(graph)

    args = list(graph.free_vars.keys())
    sufficient_statistic_nodes = find_sufficient_statistic_nodes(graph, args[1])
    sufficient_statistics = [eval_node(node, graph.free_vars,
                                       {'X': X, 'beta': beta, 'y': y})
                             for node in sufficient_statistic_nodes]
    correct_sufficient_statistics = [
        -0.5 * beta.dot(beta), beta,
        -0.5 * np.einsum('ij,ik,j,k', X, X, beta, beta)
    ]
    self.assertTrue(_match_values(sufficient_statistics,
                                  correct_sufficient_statistics))

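    # _extract_conditional_factors is expected to key each natural-parameter
    # function by a string naming the sufficient statistic of beta that it
    # multiplies in the log joint.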
    _, natural_parameter_funs = _extract_conditional_factors(graph, 'beta')
    self.assertTrue(_match_values(natural_parameter_funs.keys(),
                                  ['x', 'einsum(...a,...b->...ab, x, x)',
                                   'einsum(...,...->..., x, x)'],
                                  lambda x, y: x == y))
    natural_parameter_vals = [f(X, beta, y) for f in
                              natural_parameter_funs.values()]
    correct_parameter_vals = [-0.5 * np.ones(n_predictors), -0.5 * X.T.dot(X),
                              y.dot(X)]
    self.assertTrue(_match_values(natural_parameter_vals,
                                  correct_parameter_vals))

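    # Completing the square gives the standard Bayesian linear-regression
    # posterior: covariance (X^T X + I)^{-1} and mean cov.dot(X^T y).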
    conditional_factory = complete_conditional(log_joint, 1, SupportTypes.REAL,
                                               X, beta, y)
    conditional = conditional_factory(X, y)
    true_cov = np.linalg.inv(X.T.dot(X) + np.eye(n_predictors))
    true_mean = true_cov.dot(y.dot(X))
    self.assertTrue(np.allclose(true_cov, conditional.cov))
    self.assertTrue(np.allclose(true_mean, conditional.mean))
  def _test_condition_and_marginalize_diagonal_zero_mean_normal(self,
                                                                log_joint):
    n_dimensions = 5
    x = np.random.randn(n_dimensions)
    tau = np.random.randn(n_dimensions) ** 2

    end_node = make_expr(log_joint, x, tau)
    end_node = canonicalize(end_node)

    conditional, marginalized_value = _condition_and_marginalize(
        log_joint, 0, SupportTypes.REAL, x, tau)
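    # Integrating exp(-0.5 * sum(tau * x**2)) over x gives
    # prod(sqrt(2 * pi / tau)), so the log-marginal should be
    # -0.5 * sum(log(tau)) + 0.5 * n * log(2 * pi).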
    correct_marginalized_value = (-0.5 * np.log(tau).sum()
                                  + 0.5 * n_dimensions * np.log(2. * np.pi))
    self.assertAlmostEqual(correct_marginalized_value, marginalized_value)

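    # The conditional for x should be an elementwise normal with mean zero and
    # standard deviation tau**-0.5, exposed here through conditional.args.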
    self.assertTrue(np.allclose(np.zeros(n_dimensions), conditional.args[0]))
    self.assertTrue(np.allclose(1. / np.sqrt(tau), conditional.args[1]))
  def testLinearRegressionCanonicalize(self):
    def squared_loss(X, beta, y):
      predictions = np.einsum('ij,j->i', X, beta)
      errors = y - predictions
      return np.einsum('k,k->', errors, errors)
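    # Canonicalizing should expand the squared loss (y - X beta)^T (y - X beta)
    # into a sum of einsums while leaving its value unchanged, as the
    # assertions below check.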
    n_examples = 10
    n_predictors = 2
    X = np.random.randn(n_examples, n_predictors)
    beta = np.random.randn(n_predictors)
    y = np.random.randn(n_examples)
    expr = make_expr(squared_loss, X, beta, y)
    correct_value = eval_expr(expr, {'X': X, 'beta': beta, 'y': y})
    self.assertFalse(is_canonical(expr))
    expr = canonicalize(expr)
    self.assertTrue(is_canonical(expr))
    new_value = eval_expr(expr, {'X': X, 'beta': beta, 'y': y})
    self.assertAlmostEqual(correct_value, new_value)
  def testEinsumCompose(self):
    def Xbeta_squared(X, beta):
      Xbeta = np.einsum('ij,j->i', X, beta)
      return np.einsum('k,k->', Xbeta, Xbeta)
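    # Canonicalization should collapse the composed einsums into a single
    # einsum, roughly np.einsum('ij,j,ik,k->', X, beta, X, beta).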
    n_examples = 10
    n_predictors = 2
    X = np.random.randn(n_examples, n_predictors)
    beta = np.random.randn(n_predictors)
    expr = make_expr(Xbeta_squared, X, beta)
    correct_value = eval_expr(expr, {'X': X, 'beta': beta})
    self.assertFalse(is_canonical(expr))
    expr = canonicalize(expr)
    new_value = eval_expr(expr, {'X': X, 'beta': beta})
    self.assertAlmostEqual(correct_value, new_value)
    self.assertIsInstance(expr, GraphExpr)
    self.assertEqual(expr.expr_node.fun, np.einsum)
    self.assertTrue(is_canonical(expr))
  def testLinearRegressionStatisticRepresentation(self):
    def log_joint(X, beta, y):
      predictions = np.einsum('ij,j->i', X, beta)
      errors = y - predictions
      log_prior = np.einsum('i,i,i->', -0.5 * np.ones_like(beta), beta, beta)
      log_likelihood = np.einsum(',k,k->', -0.5, errors, errors)
      return log_prior + log_likelihood
    n_examples = 10
    n_predictors = 2
    X = np.random.randn(n_examples, n_predictors)
    beta = np.random.randn(n_predictors)
    y = np.random.randn(n_examples)

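    # statistic_representation is expected to rewrite the log joint as a
    # function of beta's sufficient statistics; grad_namedtuple then
    # differentiates with respect to those statistics to read off the natural
    # parameters checked below.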
    new_log_joint, _, stats_funs, _ = (
        statistic_representation(log_joint, (X, beta, y),
                                 (SupportTypes.REAL,), (1,)))
    beta_stat_fun = stats_funs[0]
    beta_natparam = grad_namedtuple(new_log_joint, 1)(X, beta_stat_fun(beta), y)
    correct_beta_natparam = (-0.5 * X.T.dot(X), y.dot(X),
                             -0.5 * np.ones(n_predictors))
    self.assertTrue(_match_values(beta_natparam, correct_beta_natparam))