def test_causal_stream_global_average_pool(self):
    """Causal pooling applied frame-chunk by frame-chunk matches one full pass.

    The clip has 4 frames with values 1, 2, 3, 4, each broadcast over a
    2x2 spatial grid and 3 channels. With ``causal=True`` the output at
    frame t is the running mean of frames 0..t. Streaming the clip in 1, 2,
    or 4 chunks while threading ``states`` between calls must reproduce the
    full-clip result.
    """
    pool_layer = nn_layers.GlobalAveragePool3D(keepdims=True, causal=True)

    # Frame values 1..4, tiled to shape [1, 4, 2, 2, 3].
    frame_values = tf.reshape(tf.range(4, dtype=tf.float32) + 1., [1, 4, 1, 1, 1])
    clip = tf.tile(frame_values, [1, 1, 2, 2, 3])
    full_pass, _ = pool_layer(clip)

    # Running means of 1, 2, 3, 4, each repeated over 3 channels -> [1, 4, 1, 1, 3].
    running_means = [[[[[mean] * 3]] for mean in (1.0, 1.5, 2.0, 2.5)]]

    for split_count in [1, 2, 4]:
        carried_states = {}
        chunk_outputs = []
        for chunk in tf.split(clip, split_count, axis=1):
            # State must be threaded through so each chunk sees the prefix mean.
            chunk_out, carried_states = pool_layer(chunk, states=carried_states)
            chunk_outputs.append(chunk_out)
        streamed = tf.concat(chunk_outputs, axis=1)

        self.assertEqual(streamed.shape, full_pass.shape)
        self.assertAllClose(streamed, full_pass)
        self.assertAllClose(streamed, running_means)
def test_global_average_pool_basic(self):
    """Pooling an all-ones volume with keepdims=True yields a [1, 1, 1, 1, 1] one."""
    layer = nn_layers.GlobalAveragePool3D(keepdims=True)
    volume = tf.ones([1, 2, 3, 4, 1])

    # output_states=False asks the layer to return only the pooled tensor.
    pooled = layer(volume, output_states=False)
    want = tf.ones([1, 1, 1, 1, 1])

    self.assertEqual(pooled.shape, want.shape)
    self.assertAllEqual(pooled, want)
def test_global_average_pool_keras(self):
    """GlobalAveragePool3D(keepdims=False) agrees with Keras GlobalAveragePooling3D.

    The input is drawn from a stateless RNG with a fixed seed, so any
    failure reproduces exactly on re-run. An unseeded ``tf.random.normal``
    would give a different tensor each time.
    """
    pool = nn_layers.GlobalAveragePool3D(keepdims=False)
    keras_pool = tf.keras.layers.GlobalAveragePooling3D()

    # Fixed stateless seed keeps the test deterministic across runs and test orders.
    inputs = 10 * tf.random.stateless_normal([1, 2, 3, 4, 1], seed=[1, 2])

    outputs = pool(inputs, output_states=False)
    keras_output = keras_pool(inputs)

    # Compare shapes the same way as the sibling tests in this file.
    self.assertEqual(outputs.shape, keras_output.shape)
    self.assertAllClose(outputs, keras_output)