def test_stream_squeeze_excitation(self):
  """Chunked causal squeeze-excitation must match a single full-clip pass.

  Builds a 4-frame clip whose frame t (1-based) is filled with the value t,
  runs the layer on the whole clip once, then re-runs it in chunks of 1, 2
  and 4 frames, threading the returned states from one call into the next.
  Every chunking must reproduce the full-clip output and the pinned
  reference values.
  """
  layer = movinet_layers.StreamSqueezeExcitation(
      3, causal=True, kernel_initializer='ones')

  # Shape after tiling: (1, 4, 2, 1, 3); axis 1 holds the values 1..4.
  clip = tf.reshape(tf.range(4, dtype=tf.float32) + 1., [1, 4, 1, 1, 1])
  clip = tf.tile(clip, [1, 1, 2, 1, 3])

  full_output, _ = layer(clip)

  # Pinned output: one scalar per frame, broadcast over the (2, 1, 3) trailing
  # dims. The first two frames come out slightly below their input values.
  per_frame = [0.9998109, 1.9999969, 3., 4.]
  reference = [[[[[v] * 3], [[v] * 3]] for v in per_frame]]

  for chunk_len in (1, 2, 4):
    # tf.split expects the number of pieces, not the piece length.
    num_chunks = clip.shape[1] // chunk_len
    state = {}
    pieces = []
    for chunk in tf.split(clip, num_chunks, axis=1):
      out, state = layer(chunk, states=state)
      pieces.append(out)
    streamed = tf.concat(pieces, axis=1)

    self.assertEqual(streamed.shape, full_output.shape)
    self.assertAllClose(streamed, full_output, rtol=1e-5, atol=1e-5)
    self.assertAllClose(streamed, reference, rtol=1e-5, atol=1e-5)
def test_stream_squeeze_excitation_2plus3d(self):
  """Chunked '2plus3d' causal squeeze-excitation must match a full-clip pass.

  Same protocol as the plain SE streaming check, but the layer is configured
  with se_type='2plus3d', a hard_swish activation and a hard_sigmoid gate. The
  4-frame clip (frame t filled with the value t) is run whole, then in chunks
  of 1, 2 and 4 frames with states threaded between calls. The output must
  match the full-clip result and equal the input values within atol=1e-4.
  """
  layer = movinet_layers.StreamSqueezeExcitation(
      3,
      se_type='2plus3d',
      causal=True,
      activation='hard_swish',
      gating_activation='hard_sigmoid',
      kernel_initializer='ones')

  # Shape after tiling: (1, 4, 2, 1, 3); axis 1 holds the values 1..4.
  clip = tf.reshape(tf.range(4, dtype=tf.float32) + 1., [1, 4, 1, 1, 1])
  clip = tf.tile(clip, [1, 1, 2, 1, 3])

  full_output, _ = layer(clip)

  # Pinned output: identical to the input, one scalar per frame broadcast
  # over the (2, 1, 3) trailing dims.
  per_frame = [1., 2., 3., 4.]
  reference = [[[[[v] * 3], [[v] * 3]] for v in per_frame]]

  for chunk_len in (1, 2, 4):
    # tf.split expects the number of pieces, not the piece length.
    num_chunks = clip.shape[1] // chunk_len
    state = {}
    pieces = []
    for chunk in tf.split(clip, num_chunks, axis=1):
      out, state = layer(chunk, states=state)
      pieces.append(out)
    streamed = tf.concat(pieces, axis=1)

    self.assertEqual(streamed.shape, full_output.shape)
    self.assertAllClose(streamed, full_output, atol=1e-4)
    self.assertAllClose(streamed, reference, atol=1e-4)