def test_stream_conv_block_3d_2plus1d(self):
    """Check StreamConvBlock against ConvBlock for conv_type='3d_2plus1d'.

    Both blocks are built from the same constructor arguments. The streaming
    block must reproduce the ConvBlock output when it is called once without
    `states`, and also when the input is fed in chunks along axis 1 with the
    states returned by one call passed into the next.
    """
    shared_kwargs = dict(
        filters=3,
        kernel_size=(3, 3, 3),
        strides=(1, 2, 2),
        causal=True,
        kernel_initializer='ones',
        use_bias=False,
        activation='relu',
        conv_type='3d_2plus1d',
        use_positional_encoding=True,
    )
    conv_block = movinet_layers.ConvBlock(**shared_kwargs)
    stream_conv_block = movinet_layers.StreamConvBlock(**shared_kwargs)

    inputs = tf.ones([1, 4, 2, 2, 3])
    expected = conv_block(inputs)

    # Single call on the full input, no states supplied.
    predicted_disabled, _ = stream_conv_block(inputs)
    self.assertEqual(predicted_disabled.shape, expected.shape)
    self.assertAllClose(predicted_disabled, expected)

    # Hard-coded reference output: one value per position on axis 1,
    # repeated across the three filters.
    golden = [[
        [[[value] * 3]]
        for value in (35.9640400, 71.9280700, 107.892105, 107.892105)
    ]]

    # tf.split with an integer cuts axis 1 into that many equal pieces, so
    # dividing the axis length by the loop value makes the loop value the
    # length of each chunk (1, 2, then all 4 positions at once).
    for frames_per_chunk in [1, 2, 4]:
        chunks = tf.split(inputs, inputs.shape[1] // frames_per_chunk, axis=1)

        # Start from empty states and thread them through successive calls.
        states = {}
        outputs = []
        for chunk in chunks:
            chunk_output, states = stream_conv_block(chunk, states=states)
            outputs.append(chunk_output)
        predicted = tf.concat(outputs, axis=1)

        self.assertEqual(predicted.shape, expected.shape)
        self.assertAllClose(predicted, expected)
        self.assertAllClose(predicted, golden)
def test_stream_conv_block(self):
    """Check StreamConvBlock against ConvBlock with the default conv type.

    Both blocks are built from the same constructor arguments. The streaming
    block must reproduce the ConvBlock output when it is called once without
    `states`, and also when the input is fed in chunks along axis 1 with the
    states returned by one call passed into the next.
    """
    shared_kwargs = dict(
        filters=3,
        kernel_size=(3, 3, 3),
        strides=(1, 2, 2),
        causal=True,
        kernel_initializer='ones',
        use_bias=False,
        activation='relu',
    )
    conv_block = movinet_layers.ConvBlock(**shared_kwargs)
    stream_conv_block = movinet_layers.StreamConvBlock(**shared_kwargs)

    inputs = tf.ones([1, 4, 2, 2, 3])
    expected = conv_block(inputs)

    # Single call on the full input, no states supplied.
    predicted_disabled, _ = stream_conv_block(inputs)
    self.assertEqual(predicted_disabled.shape, expected.shape)
    self.assertAllClose(predicted_disabled, expected)

    # Hard-coded reference output: one value per position on axis 1,
    # repeated across the three filters.
    golden = [[
        [[[value] * 3]]
        for value in (11.994005, 23.988010, 35.982014, 35.982014)
    ]]

    # tf.split with an integer cuts axis 1 into that many equal pieces, so
    # dividing the axis length by the loop value makes the loop value the
    # length of each chunk (1, 2, then all 4 positions at once).
    for frames_per_chunk in [1, 2, 4]:
        chunks = tf.split(inputs, inputs.shape[1] // frames_per_chunk, axis=1)

        # Start from empty states and thread them through successive calls.
        states = {}
        outputs = []
        for chunk in chunks:
            chunk_output, states = stream_conv_block(chunk, states=states)
            outputs.append(chunk_output)
        predicted = tf.concat(outputs, axis=1)

        self.assertEqual(predicted.shape, expected.shape)
        self.assertAllClose(predicted, expected)
        self.assertAllClose(predicted, golden)