def test_pad_to_multiple_1d(self):
        """Pads a rank-1 tensor along axis 0 up to a multiple of `factor`."""
        values = tf.range(3) + 1  # [1, 2, 3]

        # Each case is (factor, expected output). Padding is appended as
        # trailing zeros; factors that already divide the length (3 and 1)
        # leave the input untouched.
        cases = [
            (3, [1, 2, 3]),
            (5, [1, 2, 3, 0, 0]),
            (7, [1, 2, 3, 0, 0, 0, 0]),
            (2, [1, 2, 3, 0]),
            (1, [1, 2, 3]),
        ]
        for factor, expected in cases:
            padded = tensor_utils.pad_to_multiple(
                values, factor=factor, axis=0)
            self.assertAllEqual(expected, padded)

    def test_pad_to_multiple_static_shape(self):
        """The padded axis gets a static size; other static dims carry over.

        `placeholder_with_default` is used to simulate the TF v1 situation in
        which a static `batch_size` is unknown.
        """
        default_value = np.ones(shape=[2, 5], dtype=np.int32)
        placeholder = tf.compat.v1.placeholder_with_default(
            default_value, shape=[None, 5])

        padded = tensor_utils.pad_to_multiple(placeholder, factor=3, axis=-1)

        # The leading dim must match whatever static size the input reports,
        # while the last axis must be statically padded from 5 up to 6.
        batch_dim = placeholder.shape.as_list()[0]
        self.assertAllEqual([batch_dim, 6], padded.shape.as_list())

    def test_pad_to_multiple_padding_mode(self):
        """`mode` and `constant_values` control how the padding is filled."""
        values = tf.range(3) + 1  # [1, 2, 3]

        # Padding [1, 2, 3] to a multiple of 5 appends two elements; each
        # case pairs the extra keyword arguments with the expected result.
        cases = [
            # Mirrored, edge element not repeated.
            (dict(mode='REFLECT'), [1, 2, 3, 2, 1]),
            # Mirrored, edge element repeated.
            (dict(mode='SYMMETRIC'), [1, 2, 3, 3, 2]),
            # Constant fill with a non-zero value.
            (dict(constant_values=-1), [1, 2, 3, -1, -1]),
        ]
        for pad_kwargs, expected in cases:
            padded = tensor_utils.pad_to_multiple(
                values, factor=5, axis=0, **pad_kwargs)
            self.assertAllEqual(expected, padded)

    def test_pad_to_multiple_3d(self):
        """Pads a rank-3 tensor of ones with zeros along one chosen axis."""
        ones = tf.ones([2, 3, 5], dtype=tf.float32)

        # Axis 0 (size 2) grows to 3: one all-zero [3, 5] slice is appended.
        expected_axis0 = np.concatenate(
            [np.ones([2, 3, 5]), np.zeros([1, 3, 5])], axis=0)
        self.assertAllEqual(
            expected_axis0,
            tensor_utils.pad_to_multiple(ones, factor=3, axis=0))

        # Axis -2 already has size 3, so the tensor comes back unchanged.
        self.assertAllEqual(
            ones, tensor_utils.pad_to_multiple(ones, factor=3, axis=-2))

        # Axis -1 (size 5) grows to 6: one zero column is appended.
        expected_last_axis = np.concatenate(
            [np.ones([2, 3, 5]), np.zeros([2, 3, 1])], axis=-1)
        self.assertAllEqual(
            expected_last_axis,
            tensor_utils.pad_to_multiple(ones, factor=3, axis=-1))