    def test_padded_where_dynamic_shape(self):
        """Tests that padded_where can deal with dynamic shapes.

        The op is built once on a boolean placeholder whose two dimensions
        are both unknown, and the same graph is then run with inputs of
        different shapes. Each run requests 3 indices per row and checks both
        the returned indices and the accompanying mask.
        """
        with tf.Graph().as_default():
            condition = tf.placeholder(tf.bool, [None, None])
            indices, mask = tensor_utils.padded_where(condition, 3)

            with tf.Session() as sess:
                # Test for single batch with 8 elements.
                results = sess.run(
                    [indices, mask],
                    feed_dict={condition: [[True] * 4 + [False] * 4]})
                self.assertAllEqual(results[0], [[0, 1, 2]])
                self.assertAllEqual(results[1], [[1, 1, 1]])

                # Test for three batches with 6 elements each.
                results = sess.run(
                    [indices, mask],
                    feed_dict={
                        condition: [[False, True, False, True, True, True],
                                    [True, False, True, False, False, True],
                                    [True, True, False, False, True, False]]
                    })
                self.assertAllEqual(results[0],
                                    [[1, 3, 4], [0, 2, 5], [0, 1, 4]])
                self.assertAllEqual(results[1],
                                    [[1, 1, 1], [1, 1, 1], [1, 1, 1]])

    def test_padded_where_with_padding(self):
        """Basic padding test for padded_where.

        Every case asks for 6 indices while the condition holds fewer than 6
        True values per row, so the expected indices end in zeros and the
        expected mask is 1 for the real entries and 0 for the filler ones.
        """
        # Case where True condition values are in the left (beginning) of tensor.
        condition = tf.constant([True] * 4 + [False] * 4)
        indices, mask = tensor_utils.padded_where(condition, 6)

        self.assertAllEqual(indices, [0, 1, 2, 3, 0, 0])
        self.assertAllEqual(mask, [1, 1, 1, 1, 0, 0])

        # Case where True condition values are in the right (end) of tensor.
        condition = tf.constant([False] * 4 + [True] * 4)
        indices, mask = tensor_utils.padded_where(condition, 6)

        self.assertAllEqual(indices, [4, 5, 6, 7, 0, 0])
        self.assertAllEqual(mask, [1, 1, 1, 1, 0, 0])

        # Case where True condition values are non-contiguous.
        condition = tf.constant(
            [False, True, False, False, True, False, True, True])
        indices, mask = tensor_utils.padded_where(condition, 6)

        self.assertAllEqual(indices, [1, 4, 6, 7, 0, 0])
        self.assertAllEqual(mask, [1, 1, 1, 1, 0, 0])

        # Case with batch dimension in the condition; each row has a
        # different number of True values and hence a different pad length.
        condition = tf.constant([[False, True, False, True, True, True],
                                 [True, False, True, False, False, True],
                                 [True, False, False, False, True, False]])
        indices, mask = tensor_utils.padded_where(condition, 6)

        self.assertAllEqual(indices, [
            [1, 3, 4, 5, 0, 0],
            [0, 2, 5, 0, 0, 0],
            [0, 4, 0, 0, 0, 0],
        ])
        self.assertAllEqual(mask, [
            [1, 1, 1, 1, 0, 0],
            [1, 1, 1, 0, 0, 0],
            [1, 1, 0, 0, 0, 0],
        ])

    def test_padded_where_truncation(self):
        """Basic truncation test for padded_where.

        Every case asks for 3 indices while the condition holds more than 3
        True values per row (or exactly 3 in the last batch row), so only the
        first 3 True positions are expected and the mask is all ones.
        """
        # Case where True condition values are in the left (beginning) of tensor.
        condition = tf.constant([True] * 4 + [False] * 4)
        indices, mask = tensor_utils.padded_where(condition, 3)

        self.assertAllEqual(indices, [0, 1, 2])
        self.assertAllEqual(mask, [1, 1, 1])

        # Case where True condition values are right-shifted.
        condition = tf.constant([False] * 4 + [True] * 4)
        indices, mask = tensor_utils.padded_where(condition, 3)

        self.assertAllEqual(indices, [4, 5, 6])
        self.assertAllEqual(mask, [1, 1, 1])

        # Case where True condition values are non-contiguous.
        condition = tf.constant(
            [False, True, False, False, True, False, True, True])
        indices, mask = tensor_utils.padded_where(condition, 3)

        self.assertAllEqual(indices, [1, 4, 6])
        self.assertAllEqual(mask, [1, 1, 1])

        # Case with batch dimension in the condition.
        condition = tf.constant([[False, True, False, True, True, True],
                                 [True, False, True, False, False, True],
                                 [True, True, False, False, True, False]])
        indices, mask = tensor_utils.padded_where(condition, 3)

        self.assertAllEqual(indices, [
            [1, 3, 4],
            [0, 2, 5],
            [0, 1, 4],
        ])
        self.assertAllEqual(mask, [
            [1, 1, 1],
            [1, 1, 1],
            [1, 1, 1],
        ])