def test_none_channels_mean():
    """Trace mean_filter2d with an unknown channel axis and call it.

    One concrete function is built per image rank (3-D and 4-D), with the
    last (channel) dimension of the TensorSpec left as ``None``. Each traced
    function is then invoked with 1-channel and 3-channel all-ones inputs;
    the test passes if none of these calls raises.
    """
    cases = (
        # (TensorSpec shape with unknown channels, leading dims of the inputs)
        ((3, 3, None), (3, 3)),        # 3-D image
        ((1, 3, 3, None), (1, 3, 3)),  # 4-D image
    )
    for spec_shape, leading_dims in cases:
        concrete_fn = mean_filter2d.get_concrete_function(
            tf.TensorSpec(dtype=tf.dtypes.float32, shape=spec_shape)
        )
        for channels in (1, 3):
            concrete_fn(tf.ones(shape=leading_dims + (channels,)))
def test_unknown_shape_mean(shape):
    """Filtering an all-ones image with constant padding of 1.0 is a no-op.

    mean_filter2d is traced with a fully unknown input shape (``shape=None``),
    so a single concrete function must accept an input of whatever rank
    ``shape`` describes.

    Args:
        shape: Shape of the all-ones input image (supplied by the caller,
            presumably via test parametrization).
    """
    unknown_spec = tf.TensorSpec(shape=None, dtype=tf.dtypes.float32)
    traced = mean_filter2d.get_concrete_function(
        unknown_spec,
        padding="CONSTANT",
        constant_values=1.0,
    )
    ones = tf.ones(shape=shape)
    expected = ones.numpy()
    actual = traced(ones).numpy()
    np.testing.assert_equal(expected, actual)
def test_unknown_shape(self):
    """Run mean_filter2d, traced with an unknown shape, on 2-D/3-D/4-D inputs.

    With constant padding of 1.0, filtering an all-ones tensor is asserted
    to return a tensor equal to the input.
    """
    traced = mean_filter2d.get_concrete_function(
        tf.TensorSpec(shape=None, dtype=tf.dtypes.float32),
        padding="CONSTANT",
        constant_values=1.,
    )
    input_shapes = ((3, 3), (3, 3, 3), (1, 3, 3, 3))
    for input_shape in input_shapes:
        ones = tf.ones(shape=input_shape)
        expected = self.evaluate(ones)
        actual = self.evaluate(traced(ones))
        self.assertAllEqual(expected, actual)