def testMakeInverseUpdateOpsNoEigenDecomp(self):
    """Verify the single inverse-update op writes the exact inverse of cov.

    A symmetric 2x2 covariance is installed directly on a testing dummy
    factor. An inverse is registered under the damping function returned by
    ``make_damping_func(0)``. The test then checks three things:

    * ``make_inverse_update_ops`` returns exactly one op;
    * before that op runs, the inverse variable equals the value produced by
      ``ff.inverse_initializer`` for the covariance's shape;
    * after it runs, the inverse variable matches ``np.linalg.inv`` of the
      covariance.
    """
    with tf_ops.Graph().as_default():
        with self.test_session() as sess:
            random_seed.set_random_seed(200)

            # NOTE(mattjj): the covariance must be symmetric.
            cov_matrix = np.array([[5.0, 2.0],
                                   [2.0, 4.0]])
            dummy_factor = InverseProvidingFactorTestingDummy(cov_matrix.shape)
            dummy_factor._cov = array_ops.constant(
                cov_matrix, dtype=dtypes.float32)

            damping_fn = make_damping_func(0)
            dummy_factor.register_inverse(damping_fn)
            dummy_factor.instantiate_inv_variables()

            update_ops = dummy_factor.make_inverse_update_ops()
            self.assertEqual(1, len(update_ops))

            sess.run(tf_variables.global_variables_initializer())

            # Before the update runs, the inverse variable should still hold
            # its initializer value.
            initial_inverse = sess.run(dummy_factor.get_inverse(damping_fn))
            expected_initial = sess.run(
                ff.inverse_initializer(cov_matrix.shape, dtypes.float32))
            self.assertAllClose(expected_initial, initial_inverse)

            # Running the update op should leave the plain inverse of cov in
            # the inverse variable.
            sess.run(update_ops)
            updated_inverse = sess.run(dummy_factor.get_inverse(damping_fn))
            self.assertAllClose(updated_inverse, np.linalg.inv(cov_matrix))
  def testMakeInverseUpdateOpsNoEigenDecomp(self):
    """Verify the single inverse-update op writes the exact inverse of cov.

    Registers a damped inverse with damping 0 on a testing dummy factor that
    holds a symmetric 2x2 covariance. The test then checks three things:

    * ``make_inverse_update_ops`` returns exactly one op;
    * the inverse variable starts at ``ff.inverse_initializer``'s value;
    * after the op runs, it matches ``np.linalg.inv`` of the covariance.
    """
    with tf_ops.Graph().as_default():
      with self.test_session() as sess:
        random_seed.set_random_seed(200)

        # NOTE(mattjj): the covariance must be symmetric.
        cov_matrix = np.array([[5.0, 2.0],
                               [2.0, 4.0]])
        dummy_factor = InverseProvidingFactorTestingDummy(cov_matrix.shape)
        dummy_factor._cov = array_ops.constant(cov_matrix, dtype=dtypes.float32)

        zero_damping = 0
        dummy_factor.register_damped_inverse(zero_damping)
        update_ops = dummy_factor.make_inverse_update_ops()
        self.assertEqual(1, len(update_ops))

        sess.run(tf_variables.global_variables_initializer())
        # Inverse variable registered under damping 0 (private attribute of
        # the factor; the same object is read before and after the update).
        inverse_var = dummy_factor._inverses_by_damping[zero_damping]

        # Before the update runs, the variable holds its initializer value.
        initial_inverse = sess.run(inverse_var)
        expected_initial = sess.run(
            ff.inverse_initializer(cov_matrix.shape, dtypes.float32))
        self.assertAllClose(expected_initial, initial_inverse)

        # Running the update op should leave the plain inverse of cov there.
        sess.run(update_ops)
        updated_inverse = sess.run(inverse_var)
        self.assertAllClose(updated_inverse, np.linalg.inv(cov_matrix))