Ejemplo n.º 1
0
    def testDilationRate(self):
        """Dilating a 3x3 filter by the input width leaves only its center tap.

        With ``dilation_rate`` equal to the spatial width, the off-center taps
        of the 3x3 filter fall outside the 3x3 input. Only the center of the
        filter receives non-zero input for each input channel, so the
        covariance should have rank ``in_channels``.
        """
        with tf.Graph().as_default():
            batch_size, width = 1, 3
            in_channels, out_channels = 2, 4
            input_shape = (batch_size, width, width, in_channels)

            images = tf.random_uniform(input_shape, seed=0)
            factor = ff.ConvInputKroneckerFactor(
                inputs=(images, ),
                filter_shape=(3, 3, in_channels, out_channels),
                padding='SAME',
                extract_patches_fn='extract_image_patches',
                strides=(1, 1, 1, 1),
                dilation_rate=(1, width, width, 1),
                has_bias=False)
            factor.instantiate_cov_variables()

            with self.test_session() as sess:
                sess.run(tf.global_variables_initializer())
                sess.run(factor.make_covariance_update_op(0.0))
                cov = sess.run(factor.cov)

            expected_rank = in_channels  # one live (center) tap per channel
            self.assertMatrixRank(expected_rank, cov)
Ejemplo n.º 2
0
    def testPointwiseConv2d(self):
        """A 1x1 pointwise-conv factor accumulates one outer product per pixel.

        The covariance must be ``in_channels x in_channels``, because a 1x1
        filter sees a single pixel. The filter is applied at all
        ``width * width`` locations, so the accumulated covariance should
        have rank equal to that count.
        """
        with tf.Graph().as_default():
            batch_size, width, out_channels = 1, 3, 4
            in_channels = width**2
            num_locations = width * width

            images = tf.random_uniform(
                (batch_size, width, width, in_channels), seed=0)
            factor = ff.ConvInputKroneckerFactor(
                inputs=(images, ),
                filter_shape=(1, 1, in_channels, out_channels),
                padding='SAME',
                strides=(1, 1, 1, 1),
                extract_patches_fn='extract_pointwise_conv2d_patches',
                has_bias=False)
            factor.instantiate_cov_variables()

            # The covariance is square in the filter's (per-pixel) input size.
            expected_shape = [in_channels, in_channels]
            self.assertEqual(expected_shape, factor.cov.shape.as_list())

            # Building and running the update op must not raise.
            with self.test_session() as sess:
                sess.run(tf.global_variables_initializer())
                sess.run(factor.make_covariance_update_op(0.0))
                cov = sess.run(factor.cov)

            self.assertMatrixRank(num_locations, cov)
Ejemplo n.º 3
0
    def test3DConvolution(self):
        """Checks the input covariance of a strided 3-D convolution.

        A ``width**3`` single-example input with ``in_channels`` channels is
        convolved with a ``width**3`` filter. The convolution uses 'SAME'
        padding and a stride of 2 along each of the three spatial axes.
        """
        with tf.Graph().as_default():
            batch_size = 1
            width = 3
            in_channels = 3**3
            out_channels = 4
            stride = 2
            num_spatial_dims = 3

            factor = ff.ConvInputKroneckerFactor(
                inputs=(tf.random_uniform(
                    (batch_size, width, width, width, in_channels), seed=0), ),
                filter_shape=(width, width, width, in_channels, out_channels),
                padding='SAME',
                strides=(stride, ) * num_spatial_dims,
                extract_patches_fn='extract_convolution_patches',
                has_bias=False)
            factor.instantiate_cov_variables()

            # The covariance is over the flattened filter input, which has
            # in_channels * width**3 entries.
            input_size = in_channels * (width**num_spatial_dims)
            self.assertEqual([input_size, input_size],
                             factor.get_cov_var().shape.as_list())

            # Building and running the update op must not raise.
            with self.test_session() as sess:
                sess.run(tf.global_variables_initializer())
                sess.run(factor.make_covariance_update_op(0.0))
                cov = sess.run(factor.get_cov_var())

            # Under 'SAME' padding each spatial axis yields ceil(width / stride)
            # output positions: ceil(3 / 2) = 2. The filter is therefore applied
            # 2**3 = 8 times, once at each corner of the 3-D cube. Cov is a sum
            # of that many outer products, so it should have rank 8.
            positions_per_axis = (width + stride - 1) // stride
            self.assertMatrixRank(positions_per_axis**num_spatial_dims, cov)
Ejemplo n.º 4
0
 def testConvInputKroneckerFactorInit(self):
     """The covariance is square in (flattened filter input size + 1).

     The filter shape (1, 2, 3, 4) gives a 1 * 2 * 3 flattened input per
     output. The extra row/column is expected because ``has_bias=True``.
     """
     with tf.Graph().as_default():
         expected_dim = 1 * 2 * 3 + 1
         inputs = tf.ones((64, 1, 2, 3), name='a/b/c')
         factor = ff.ConvInputKroneckerFactor(
             (inputs, ),
             filter_shape=(1, 2, 3, 4),
             padding='SAME',
             has_bias=True)
         factor.instantiate_cov_variables()
         actual_shape = factor.cov.get_shape().as_list()
         self.assertEqual([expected_dim, expected_dim], actual_shape)
Ejemplo n.º 5
0
 def testConvInputKroneckerFactorInitFloat64(self):
     """A float64 input produces a float64 (ref-typed) covariance variable.

     Also checks the covariance shape, which is (1 * 2 * 3 + 1) on each side
     because ``has_bias=True``.
     """
     with tf.Graph().as_default():
         expected_dim = 1 * 2 * 3 + 1
         inputs = tf.ones((64, 1, 2, 3), name='a/b/c', dtype=tf.float64)
         factor = ff.ConvInputKroneckerFactor(
             (inputs, ),
             filter_shape=(1, 2, 3, 4),
             padding='SAME',
             has_bias=True)
         factor.instantiate_cov_variables()

         cov_var = factor.cov
         self.assertEqual(cov_var.dtype, dtypes.float64_ref)
         self.assertEqual([expected_dim, expected_dim],
                          cov_var.get_shape().as_list())
Ejemplo n.º 6
0
    def testMakeCovarianceUpdateOpNoBias(self):
        """One covariance update yields the mean of the squared inputs.

        The input is a batch of two 1x1 single-channel images holding the values
        1 and 2. After running ``make_covariance_update_op(0.)`` once, the 1x1
        covariance should be (1**2 + 2**2) / 2.
        """
        with tf.Graph().as_default(), self.test_session() as sess:
            input_shape = (2, 1, 1, 1)
            num_elements = np.prod(input_shape)
            values = np.arange(1, 1 + num_elements).reshape(input_shape)
            tensor = tf.constant(values.astype(np.float32))

            factor = ff.ConvInputKroneckerFactor(
                (tensor, ), filter_shape=(1, 1, 1, 1), padding='SAME')
            factor.instantiate_cov_variables()

            sess.run(tf.global_variables_initializer())
            new_cov = sess.run(factor.make_covariance_update_op(0.))

            expected = [[(1. + 4.) / 2.]]
            self.assertAllClose(expected, new_cov)
Ejemplo n.º 7
0
    def testStrides(self):
        """A height stride of 2 halves (rounding up) the number of patch rows.

        On a 3x3 input with a 1x1 filter and 'SAME' padding, strides of
        (1, 2, 1, 1) visit 2 rows x 3 columns = 6 positions. The covariance
        is therefore a sum of 6 outer products.
        """
        with tf.Graph().as_default():
            batch_size, width = 1, 3
            in_channels, out_channels = 3**2, 4
            num_patches = 3 * 2

            images = tf.random_uniform(
                (batch_size, width, width, in_channels), seed=0)
            factor = ff.ConvInputKroneckerFactor(
                inputs=(images, ),
                filter_shape=(1, 1, in_channels, out_channels),
                padding='SAME',
                strides=(1, 2, 1, 1),
                extract_patches_fn='extract_image_patches',
                has_bias=False)
            factor.instantiate_cov_variables()

            with self.test_session() as sess:
                sess.run(tf.global_variables_initializer())
                sess.run(factor.make_covariance_update_op(0.0))
                cov = sess.run(factor.cov)

            self.assertMatrixRank(num_patches, cov)