def test_constant_samples(self):
  """Checks the appended feature against a closed-form standard deviation.

  Sample i of the batch is filled entirely with the constant value i + 1, so
  at every position the values across the batch are 1, 2, ..., batch_size.
  The expected extra feature is the population standard deviation of those
  values, broadcast over all positions and appended on axis 3.
  """
  batch_size = 3
  resolution = 64
  n_channels = 3
  # A separate comprehension variable avoids shadowing the `batch` tensor
  # that this expression defines.
  batch = tf.concat([
      tf.ones([1, resolution, resolution, n_channels]) * (sample_index + 1)
      for sample_index in range(batch_size)
  ], 0)

  # Derive the oracle from batch_size instead of hard-coding 1, 2, 3 and a
  # mean of 2.0, so changing batch_size keeps the expectation correct.
  sample_values = [float(sample_index + 1) for sample_index in range(batch_size)]
  mean = sum(sample_values) / batch_size
  expected_stddeviation = tf.sqrt(
      sum((value - mean)**2 for value in sample_values) / batch_size)

  layer = minibatch_standarddeviation.MiniBatchStandardDeviation()
  layer.build(batch.shape)
  expected = tf.concat(
      [batch, tf.broadcast_to(expected_stddeviation, batch.shape[:-1] + [1])],
      3)
  difference = tf.abs(layer.call(batch) - expected)
  epsilon = 1e-5
  is_equal = difference < epsilon
  self.assertTrue(tf.math.reduce_all(is_equal))
def test_invalid_batch_axis_argument(self):
  """Constructing the layer with a non-integer batch_axis raises TypeError."""
  # The failure message previously referred to channel_axis, although this
  # test exercises batch_axis.
  with self.assertRaises(
      TypeError,
      msg="An invalid batch_axis argument should cause a TypeError."):
    minibatch_standarddeviation.MiniBatchStandardDeviation(
        batch_axis="expects an integer")
def test_zero_deviation(self):
  """A batch of identical samples must get an all-zero appended feature.

  The output is compared for exact equality with the input concatenated
  along axis 3 with a single channel of zeros.
  """
  batch_size, resolution, n_channels = 8, 64, 3
  batch = tf.ones([batch_size, resolution, resolution, n_channels])

  layer = minibatch_standarddeviation.MiniBatchStandardDeviation()
  layer.build(batch.shape)
  actual = layer.call(batch)

  zero_feature = tf.broadcast_to(0.0, batch.shape[:-1] + [1])
  expected = tf.concat([batch, zero_feature], 3)
  self.assertTrue(tf.math.reduce_all(tf.equal(actual, expected)))
def test_random(self):
  """Compares the layer output with a generated (input, expected) pair.

  `create_random_testcase` is defined elsewhere in this file. It returns the
  input batch and the output the layer is expected to produce for it.
  """
  batch_size, resolution, n_channels = 3, 64, 3
  shape = [batch_size, resolution, resolution, n_channels]
  batch, expected = create_random_testcase(shape)

  layer = minibatch_standarddeviation.MiniBatchStandardDeviation()
  layer.build(batch.shape)

  tolerance = 1e-5
  within_tolerance = tf.abs(layer.call(batch) - expected) < tolerance
  self.assertTrue(tf.math.reduce_all(within_tolerance))
def test_channel_axis_too_large(self):
  """Calling the layer with an out-of-range channel_axis raises ValueError.

  Construction and build() are expected to succeed; only call() must fail.
  """
  batch_size, resolution, n_channels = 8, 64, 3
  batch = tf.ones([batch_size, resolution, resolution, n_channels])
  layer = minibatch_standarddeviation.MiniBatchStandardDeviation(
      channel_axis=10)
  layer.build(batch.shape)
  try:
    layer.call(batch)
  except ValueError:
    pass
  else:
    self.fail("Too large a channel_axis argument should cause a ValueError.")
def test_config(self):
  """get_config() returns the layer's own axes merged with the base config.

  On a key collision the value from the parent class's get_config() wins.
  """
  layer = minibatch_standarddeviation.MiniBatchStandardDeviation(
      channel_axis=3, batch_axis=2)
  config = layer.get_config()

  parent_config = super(
      minibatch_standarddeviation.MiniBatchStandardDeviation,
      layer).get_config()
  expected_config = {"batch_axis": 2, "channel_axis": 3}
  expected_config.update(parent_config)

  self.assertEqual(config, expected_config)
def test_distributed_with_replica_ctx(self):
  """Runs the synced layer in replica context and checks it against the plain layer.

  Every replica is given the same full batch via
  `experimental_distribute_values_from_function`. The output on each local
  replica must therefore match the non-distributed
  `MiniBatchStandardDeviation` result within a small tolerance.
  """
  batch_size = 3
  resolution = 64
  n_channels = 3
  batch = tf.random.uniform([batch_size, resolution, resolution, n_channels])
  layer = minibatch_standarddeviation.MiniBatchStandardDeviation()
  layer.build(batch.shape)
  non_distributed = layer.call(batch)

  strategy = tf.distribute.MirroredStrategy()
  layer_dist = minibatch_standarddeviation.SyncMiniBatchStandardDeviation()
  layer_dist.build(batch.shape)
  per_replica_inputs = strategy.experimental_distribute_values_from_function(
      lambda _: batch)
  per_replica_outputs = strategy.run(
      layer_dist.call, args=(per_replica_inputs,))

  # With more than one device `strategy.run` returns a PerReplica value,
  # which cannot be subtracted from a plain tensor. Compare each local
  # replica's result instead. With a single device this yields one tensor.
  epsilon = 1e-5
  for distributed in strategy.experimental_local_results(per_replica_outputs):
    difference = tf.abs(non_distributed - distributed)
    is_equal = difference < epsilon
    self.assertTrue(tf.math.reduce_all(is_equal))
def test_distributed_without_replica_ctx(self):
  """Calls the synced layer under strategy.scope() and checks it against the plain layer.

  The synced layer is called directly inside the strategy scope, not through
  strategy.run. Its output must match the non-distributed
  `MiniBatchStandardDeviation` result within a small tolerance.
  """
  batch_size, resolution, n_channels = 3, 64, 3
  batch = tf.random.uniform([batch_size, resolution, resolution, n_channels])

  reference_layer = minibatch_standarddeviation.MiniBatchStandardDeviation()
  reference_layer.build(batch.shape)
  reference = reference_layer.call(batch)

  strategy = tf.distribute.MirroredStrategy()
  synced_layer = minibatch_standarddeviation.SyncMiniBatchStandardDeviation()
  synced_layer.build(batch.shape)
  with strategy.scope():
    synced = synced_layer.call(batch)

  tolerance = 1e-5
  close_enough = tf.abs(reference - synced) < tolerance
  self.assertTrue(tf.math.reduce_all(close_enough))