def test3DConvolution(self):
    """Covariance of a strided 3-D convolution has the expected shape and rank.

    The input is a single 3x3x3 volume with 27 channels. With stride 2 and
    'SAME' padding, the 3x3x3 filter is applied at a 2x2x2 grid of locations
    (the 8 corners of the volume). The covariance is therefore a sum of 8
    patch outer products and has rank 8.
    """
    with tf_ops.Graph().as_default():
        batch_size = 1
        width = 3
        in_channels = 3**3
        out_channels = 4
        factor = ff.ConvInputKroneckerFactor(
            inputs=(random_ops.random_uniform(
                (batch_size, width, width, width, in_channels), seed=0),),
            filter_shape=(width, width, width, in_channels, out_channels),
            padding='SAME',
            strides=(2, 2, 2),
            extract_patches_fn='extract_convolution_patches',
            has_bias=False)
        factor.instantiate_cov_variables()

        # Ensure shape of covariance matches input size of filter.
        input_size = in_channels * (width**3)
        self.assertEqual([input_size, input_size],
                         factor.get_cov().shape.as_list())

        # Ensure cov_update_op doesn't crash.
        with self.cached_session() as sess:
            sess.run(tf_variables.global_variables_initializer())
            sess.run(factor.make_covariance_update_op(0.0))
            cov = sess.run(factor.get_cov())

        # Cov should be rank-8, as the filter will be applied at each corner of
        # the 3-D cube (2 output positions per spatial axis, 2**3 in total).
        self.assertMatrixRank(8, cov)
def testDilationRate(self):
    """A heavily dilated 3x3 filter only receives input at its centre tap.

    The input is 3x3 spatially and the dilation rate equals that width. With
    'SAME' padding every off-centre tap therefore lands outside the image.
    The patch covariance should have rank equal to the number of input
    channels.
    """
    channels_in = 2
    spatial = 3
    with tf_ops.Graph().as_default():
        input_tensor = random_ops.random_uniform(
            (1, spatial, spatial, channels_in), seed=0)
        factor = ff.ConvInputKroneckerFactor(
            inputs=(input_tensor,),
            filter_shape=(3, 3, channels_in, 4),
            padding='SAME',
            extract_patches_fn='extract_image_patches',
            strides=(1, 1, 1, 1),
            dilation_rate=(1, spatial, spatial, 1),
            has_bias=False)
        factor.instantiate_cov_variables()

        with self.test_session() as sess:
            sess.run(tf_variables.global_variables_initializer())
            sess.run(factor.make_covariance_update_op(0.0))
            cov = sess.run(factor.get_cov_var())

        # Only the centre tap of each channel contributes non-zero input.
        self.assertMatrixRank(channels_in, cov)
def testPointwiseConv2d(self):
    """A 1x1 (pointwise) filter over a 3x3 image yields a rank-9 covariance.

    The covariance is square in the filter's input size. For a 1x1 kernel
    that size is just the channel count. One outer product is accumulated per
    spatial location.
    """
    n_channels = 3**2
    side = 3
    with tf_ops.Graph().as_default():
        images = random_ops.random_uniform((1, side, side, n_channels), seed=0)
        factor = ff.ConvInputKroneckerFactor(
            inputs=(images,),
            filter_shape=(1, 1, n_channels, 4),
            padding='SAME',
            strides=(1, 1, 1, 1),
            extract_patches_fn='extract_pointwise_conv2d_patches',
            has_bias=False)
        factor.instantiate_cov_variables()

        # Covariance must be [input size, input size] of the 1x1 filter.
        expected_shape = [n_channels, n_channels]
        self.assertEqual(expected_shape, factor.get_cov_var().shape.as_list())

        # Running the update op must not raise.
        with self.test_session() as sess:
            sess.run(tf_variables.global_variables_initializer())
            sess.run(factor.make_covariance_update_op(0.0))
            cov = sess.run(factor.get_cov_var())

        # One outer product per spatial location: 3 * 3 = 9.
        self.assertMatrixRank(9, cov)
def testConvInputKroneckerFactorInit(self):
    """A biased conv factor has a (filter input size + 1)-square covariance.

    NOTE(review): another method with this exact name appears later in this
    listing. If both live in the same class, only the later definition runs.
    This copy uses the same tuple-of-inputs / keyword-argument call, so it
    stays valid whichever definition is kept.
    """
    with tf_ops.Graph().as_default():
        tensor = array_ops.ones((64, 1, 2, 3), name='a/b/c')
        factor = ff.ConvInputKroneckerFactor(
            (tensor,),
            filter_shape=(1, 2, 3, 4),
            padding='SAME',
            has_bias=True)
        # The other factor tests here instantiate the covariance variables
        # before reading them; do the same.
        factor.instantiate_cov_variables()
        # 1*2*3 patch entries (filter_shape[:3]) plus one bias entry.
        self.assertEqual([1 * 2 * 3 + 1, 1 * 2 * 3 + 1],
                         factor.get_cov().get_shape().as_list())
def testConvInputKroneckerFactorInit(self):
    """With has_bias=True the covariance gains one extra row and column."""
    patch_size = 1 * 2 * 3  # product of filter_shape[:3]
    with tf_ops.Graph().as_default():
        ones_input = array_ops.ones((64, 1, 2, 3), name='a/b/c')
        factor = ff.ConvInputKroneckerFactor(
            (ones_input,),
            filter_shape=(1, 2, 3, 4),
            padding='SAME',
            has_bias=True)
        factor.instantiate_cov_variables()
        cov_shape = factor.get_cov().get_shape().as_list()
        self.assertEqual([patch_size + 1, patch_size + 1], cov_shape)
def testConvInputKroneckerFactorInitNoBias(self):
    """Without a bias term the covariance is exactly (filter input size) square.

    Mirrors the biased initialisation test: a 4-D input passed as a
    one-element tuple, with filter_shape and padding given by keyword.
    """
    with tf_ops.Graph().as_default():
        tensor = array_ops.ones((64, 1, 2, 3), name='a/b/c')
        factor = ff.ConvInputKroneckerFactor(
            (tensor,),
            filter_shape=(1, 2, 3, 4),
            padding='SAME',
            has_bias=False)
        factor.instantiate_cov_variables()
        # 1*2*3 patch entries (filter_shape[:3]); no extra bias entry.
        self.assertEqual([1 * 2 * 3, 1 * 2 * 3],
                         factor.get_cov().get_shape().as_list())
def testMakeCovarianceUpdateOpNoBias(self):
    """Checks one covariance update on a fixed 2x2x2x2 input (argument 0.5).

    NOTE(review): unlike the other factor tests here, this passes positional
    arguments that appear to be (filter_shape, strides, padding) with a bare
    tensor. It also never calls instantiate_cov_variables(). A later method in
    this listing shares this name; if both are in one class, this definition
    is shadowed and never runs. The expected values were presumably computed
    for that legacy API, so confirm before porting.
    """
    with tf_ops.Graph().as_default(), self.test_session() as sess:
        random_seed.set_random_seed(200)
        values = np.arange(1., 17.).reshape(2, 2, 2, 2)
        tensor = array_ops.constant(values, dtype=dtypes.float32)
        factor = ff.ConvInputKroneckerFactor(
            tensor, (1, 2, 1, 1), [1, 1, 1, 1], 'SAME')

        sess.run(tf_variables.global_variables_initializer())
        expected = [[34.375, 37], [37, 41]]
        self.assertAllClose(expected,
                            sess.run(factor.make_covariance_update_op(.5)))
def testConvInputKroneckerFactorInitFloat64(self):
    """A float64 input yields a float64 covariance of (patch size + 1)^2.

    NOTE(review): another method with this exact name appears later in this
    listing. If both live in the same class, only the later definition runs.
    This copy uses the same tuple-of-inputs / keyword-argument call, so it
    stays valid whichever definition is kept.
    """
    with tf_ops.Graph().as_default():
        # The covariance variable is expected to report the reference dtype.
        dtype = dtypes.float64_ref
        tensor = array_ops.ones(
            (64, 1, 2, 3), name='a/b/c', dtype=dtypes.float64)
        factor = ff.ConvInputKroneckerFactor(
            (tensor,),
            filter_shape=(1, 2, 3, 4),
            padding='SAME',
            has_bias=True)
        # The other factor tests here instantiate the covariance variables
        # before reading them; do the same.
        factor.instantiate_cov_variables()
        cov = factor.get_cov()
        self.assertEqual(cov.dtype, dtype)
        # 1*2*3 patch entries (filter_shape[:3]) plus one bias entry.
        self.assertEqual([1 * 2 * 3 + 1, 1 * 2 * 3 + 1],
                         cov.get_shape().as_list())
def testMakeCovarianceUpdateOpNoBias(self):
    """The update op with argument 0 returns the mean squared input.

    Two single-pixel, single-channel examples with values 1 and 2 pass
    through a 1x1 filter. The expected covariance is (1 + 4) / 2.
    """
    with tf_ops.Graph().as_default(), self.test_session() as sess:
        shape = (2, 1, 1, 1)
        flat = np.arange(1, 1 + np.prod(shape))
        tensor = array_ops.constant(flat.reshape(shape).astype(np.float32))
        factor = ff.ConvInputKroneckerFactor(
            (tensor,), filter_shape=(1, 1, 1, 1), padding='SAME')
        factor.instantiate_cov_variables()

        sess.run(tf_variables.global_variables_initializer())
        update_op = factor.make_covariance_update_op(0.)
        self.assertAllClose([[(1. + 4.) / 2.]], sess.run(update_op))
def testConvInputKroneckerFactorInitFloat64(self):
    """A float64 input yields a float64 covariance of (patch size + 1)^2."""
    with tf_ops.Graph().as_default():
        # The covariance variable is expected to report the reference dtype.
        dtype = dtypes.float64_ref
        tensor = array_ops.ones(
            (64, 1, 2, 3), name='a/b/c', dtype=dtypes.float64)
        # Inputs are passed as a one-element tuple, like every other
        # ConvInputKroneckerFactor construction in these tests.
        factor = ff.ConvInputKroneckerFactor(
            (tensor,),
            filter_shape=(1, 2, 3, 4),
            padding='SAME',
            has_bias=True)
        factor.instantiate_cov_variables()
        cov = factor.get_cov()
        self.assertEqual(cov.dtype, dtype)
        # 1*2*3 patch entries (filter_shape[:3]) plus one bias entry.
        self.assertEqual([1 * 2 * 3 + 1, 1 * 2 * 3 + 1],
                         cov.get_shape().as_list())
def testStrides(self):
    """Stride 2 on the first spatial axis of a 3x3 image leaves 6 locations.

    The 1x1 filter is applied at 2 positions along the strided axis and 3
    along the other. The covariance is a sum of 2 * 3 = 6 outer products.
    """
    depth = 3**2
    size = 3
    with tf_ops.Graph().as_default():
        images = random_ops.random_uniform((1, size, size, depth), seed=0)
        factor = ff.ConvInputKroneckerFactor(
            inputs=(images,),
            filter_shape=(1, 1, depth, 4),
            padding='SAME',
            strides=(1, 2, 1, 1),
            extract_patches_fn='extract_image_patches',
            has_bias=False)
        factor.instantiate_cov_variables()

        with self.test_session() as sess:
            sess.run(tf_variables.global_variables_initializer())
            sess.run(factor.make_covariance_update_op(0.0))
            cov = sess.run(factor.get_cov_var())

        # Cov should be the sum of 3 * 2 = 6 outer products.
        self.assertMatrixRank(6, cov)