def _testFullyConnectedKroneckerFactorInit(self, has_bias, final_shape):
    """Build a fully connected factor and check its covariance shape.

    A (2, 3) tensor of ones named 'a/b/c' is wrapped in a 1-tuple and handed
    to `ff.FullyConnectedKroneckerFactor` inside a fresh graph; the static
    shape of `get_cov()` is then compared against `final_shape`.

    Args:
        has_bias: Forwarded unchanged as the factor's `has_bias` argument.
        final_shape: Expected covariance shape, as a list of ints.
    """
    graph = tf_ops.Graph()
    with graph.as_default():
        random_seed.set_random_seed(200)
        inputs = array_ops.ones((2, 3), name='a/b/c')
        fc_factor = ff.FullyConnectedKroneckerFactor(
            (inputs,), has_bias=has_bias)
        cov_shape = fc_factor.get_cov().get_shape().as_list()
        self.assertEqual(final_shape, cov_shape)
def testMakeCovarianceUpdateOpNoBias(self):
    """Run one covariance update on a bias-free factor and check the values.

    The factor is built from the constant [[1, 2], [3, 4]] without passing
    `has_bias`; after initializing global variables, the op returned by
    `make_covariance_update_op(.5)` is evaluated and its result compared
    against [[3, 3.5], [3.5, 5.5]].
    """
    with tf_ops.Graph().as_default():
        with self.test_session() as sess:
            random_seed.set_random_seed(200)
            inputs = array_ops.constant([[1., 2.], [3., 4.]], name='a/b/c')
            fc_factor = ff.FullyConnectedKroneckerFactor((inputs,))

            # Variables must be initialized before the update op can run.
            init_op = tf_variables.global_variables_initializer()
            sess.run(init_op)

            # NOTE(review): .5 is presumably a decay/averaging coefficient --
            # confirm against make_covariance_update_op's signature.
            update_op = fc_factor.make_covariance_update_op(.5)
            updated_cov = sess.run(update_op)

            expected_cov = [[3, 3.5], [3.5, 5.5]]
            self.assertAllClose(expected_cov, updated_cov)
def _testFullyConnectedKroneckerFactorInit(self,
                                           has_bias,
                                           final_shape,
                                           dtype=dtypes.float32_ref):
    """Build a fully connected factor and check covariance dtype and shape.

    A (2, 3) tensor of ones named 'a/b/c' is created with `dtype`, nested as
    `((tensor,),)` and handed to `ff.FullyConnectedKroneckerFactor` inside a
    fresh graph. After `instantiate_cov_variables()` is called, the
    covariance returned by `get_cov()` is checked for dtype and static shape.

    Args:
        has_bias: Forwarded unchanged as the factor's `has_bias` argument.
        final_shape: Expected covariance shape, as a list of ints.
        dtype: Dtype used for the input tensor and expected of the
            covariance; defaults to `dtypes.float32_ref`.
    """
    graph = tf_ops.Graph()
    with graph.as_default():
        random_seed.set_random_seed(200)
        inputs = array_ops.ones((2, 3), dtype=dtype, name='a/b/c')
        nested_inputs = ((inputs,),)
        fc_factor = ff.FullyConnectedKroneckerFactor(
            nested_inputs, has_bias=has_bias)

        # The covariance variables are created explicitly before being read.
        fc_factor.instantiate_cov_variables()
        covariance = fc_factor.get_cov()

        self.assertEqual(covariance.dtype, dtype)
        cov_shape = covariance.get_shape().as_list()
        self.assertEqual(final_shape, cov_shape)