def testNaiveDiagonalFactorInit(self):
    """Check the static shape of a freshly instantiated diagonal covariance.

    A NaiveDiagonalFactor is built from a single (2, 3) tensor of ones and
    its covariance is expected to have static shape [6, 1].
    """
    with tf_ops.Graph().as_default():
        random_seed.set_random_seed(200)
        ones = array_ops.ones((2, 3), name='a/b/c')
        # NOTE(review): 32 is the second positional constructor argument,
        # presumably the batch size -- confirm against ff.NaiveDiagonalFactor.
        diag_factor = ff.NaiveDiagonalFactor((ones,), 32)
        diag_factor.instantiate_cov_variables()
        cov_shape = diag_factor.get_cov().get_shape().as_list()
        self.assertEqual([6, 1], cov_shape)
def testMakeCovarianceUpdateOp(self):
    """Run one covariance update and check the resulting values.

    A NaiveDiagonalFactor is built from the constant tensor [1., 2.], its
    covariance update op is created with argument .5 and executed once, and
    the result is compared against [[0.75], [1.5]].
    """
    with tf_ops.Graph().as_default(), self.test_session() as sess:
        random_seed.set_random_seed(200)
        tensor = array_ops.constant([1., 2.], name='a/b/c')
        factor = ff.NaiveDiagonalFactor((tensor,), 2)
        # Create the covariance variables before running the global
        # initializer, mirroring testNaiveDiagonalFactorInit; otherwise the
        # initializer and the update op run without them having been created.
        factor.instantiate_cov_variables()
        sess.run(tf_variables.global_variables_initializer())
        # NOTE(review): the expected values look like a .5-weighted average
        # of an all-ones initial covariance and tensor**2 / 2 -- confirm
        # against ff.NaiveDiagonalFactor.
        new_cov = sess.run(factor.make_covariance_update_op(.5))
        self.assertAllClose([[0.75], [1.5]], new_cov)
def testNaiveDiagonalFactorInitFloat64(self):
    """Check dtype and static shape of the covariance for float64 input.

    A NaiveDiagonalFactor is built from a single (2, 3) tensor of ones
    created with dtypes.float64_ref. The covariance is expected to report
    that same dtype and to have static shape [6, 1].
    """
    with tf_ops.Graph().as_default():
        dtype = dtypes.float64_ref
        random_seed.set_random_seed(200)
        tensor = array_ops.ones((2, 3), dtype=dtype, name='a/b/c')
        factor = ff.NaiveDiagonalFactor((tensor,), 32)
        # Instantiate the covariance variables before reading the covariance,
        # consistent with testNaiveDiagonalFactorInit.
        factor.instantiate_cov_variables()
        cov = factor.get_cov()
        self.assertEqual(cov.dtype, dtype)
        self.assertEqual([6, 1], cov.get_shape().as_list())