def test_stream_squeeze_excitation(self):
        """Checks that streamed causal SE output matches the full-clip output.

        The layer is first applied to the whole 4-frame clip. The clip is then
        fed to the same layer chunk by chunk, passing the ``states`` returned by
        one call into the next call. The concatenated streamed output must match
        the full-clip output and a hard-coded reference, for chunk lengths of
        1, 2 and 4 frames.
        """
        se = movinet_layers.StreamSqueezeExcitation(3,
                                                    causal=True,
                                                    kernel_initializer='ones')

        # Frame values 1..4 along axis 1, tiled to shape [1, 4, 2, 1, 3].
        inputs = tf.range(4, dtype=tf.float32) + 1.
        inputs = tf.reshape(inputs, [1, 4, 1, 1, 1])
        inputs = tf.tile(inputs, [1, 1, 2, 1, 3])
        expected, _ = se(inputs)

        num_frames = inputs.shape[1]
        # The loop value is the number of frames per streaming call. With an
        # int argument, tf.split takes the number of chunks, hence the division.
        for frames_per_step in [1, 2, 4]:
            # subTest reports which chunk length failed instead of aborting
            # the whole loop on the first mismatch.
            with self.subTest(frames_per_step=frames_per_step):
                chunks = tf.split(inputs, num_frames // frames_per_step, axis=1)
                states = {}
                predicted = []
                for chunk in chunks:
                    # Thread the state from the previous call into this one.
                    x, states = se(chunk, states=states)
                    predicted.append(x)
                predicted = tf.concat(predicted, axis=1)

                self.assertEqual(predicted.shape, expected.shape)
                self.assertAllClose(predicted, expected, rtol=1e-5, atol=1e-5)

                self.assertAllClose(
                    predicted,
                    [[[[[0.9998109, 0.9998109, 0.9998109]],
                       [[0.9998109, 0.9998109, 0.9998109]]],
                      [[[1.9999969, 1.9999969, 1.9999969]],
                       [[1.9999969, 1.9999969, 1.9999969]]],
                      [[[3., 3., 3.]], [[3., 3., 3.]]],
                      [[[4., 4., 4.]], [[4., 4., 4.]]]]],
                    rtol=1e-5,
                    atol=1e-5)
Example #2
0
    def test_stream_squeeze_excitation_2plus3d(self):
        """Checks streamed '2plus3d' SE output against the full-clip output.

        The layer is built with ``se_type='2plus3d'``, a hard-swish activation
        and hard-sigmoid gating. It is applied once to the whole 4-frame clip,
        then chunk by chunk, passing the returned ``states`` into the next call.
        The concatenated streamed output must match the full-clip output and a
        hard-coded reference, for chunk lengths of 1, 2 and 4 frames.
        """
        se = movinet_layers.StreamSqueezeExcitation(
            3,
            se_type='2plus3d',
            causal=True,
            activation='hard_swish',
            gating_activation='hard_sigmoid',
            kernel_initializer='ones')

        # Frame values 1..4 along axis 1, tiled to shape [1, 4, 2, 1, 3].
        inputs = tf.range(4, dtype=tf.float32) + 1.
        inputs = tf.reshape(inputs, [1, 4, 1, 1, 1])
        inputs = tf.tile(inputs, [1, 1, 2, 1, 3])
        expected, _ = se(inputs)

        num_frames = inputs.shape[1]
        # The loop value is the number of frames per streaming call. With an
        # int argument, tf.split takes the number of chunks, hence the division.
        for frames_per_step in [1, 2, 4]:
            # subTest reports which chunk length failed instead of aborting
            # the whole loop on the first mismatch.
            with self.subTest(frames_per_step=frames_per_step):
                chunks = tf.split(inputs, num_frames // frames_per_step, axis=1)
                states = {}
                predicted = []
                for chunk in chunks:
                    # Thread the state from the previous call into this one.
                    x, states = se(chunk, states=states)
                    predicted.append(x)
                predicted = tf.concat(predicted, axis=1)

                self.assertEqual(predicted.shape, expected.shape)
                self.assertAllClose(predicted, expected, atol=1e-4)

                self.assertAllClose(
                    predicted,
                    [[[[[1., 1., 1.]], [[1., 1., 1.]]],
                      [[[2., 2., 2.]], [[2., 2., 2.]]],
                      [[[3., 3., 3.]], [[3., 3., 3.]]],
                      [[[4., 4., 4.]], [[4., 4., 4.]]]]],
                    atol=1e-4)