def test_padded_where_dynamic_shape(self):
    """Tests that padded_where can deal with dynamic shapes.

    The condition is a rank-2 placeholder whose dimensions are both unknown
    at graph-construction time. Different shapes are fed at run time.
    Both the truncation path (more True values than requested) and the
    padding path (fewer True values than requested) are exercised.
    """
    with tf.Graph().as_default():
        condition = tf.placeholder(tf.bool, [None, None])
        indices, mask = tensor_utils.padded_where(condition, 3)
        with tf.Session() as sess:
            # Single batch with 8 elements; the 4 True values are truncated
            # to the first 3.
            results = sess.run(
                [indices, mask],
                feed_dict={condition: [[True] * 4 + [False] * 4]})
            self.assertAllEqual(results[0], [[0, 1, 2]])
            self.assertAllEqual(results[1], [[1, 1, 1]])

            # Three batches with 6 elements each; every row has at least 3
            # True values, so every row is truncated or exactly filled.
            results = sess.run(
                [indices, mask],
                feed_dict={
                    condition: [[False, True, False, True, True, True],
                                [True, False, True, False, False, True],
                                [True, True, False, False, True, False]]
                })
            self.assertAllEqual(results[0], [[1, 3, 4], [0, 2, 5], [0, 1, 4]])
            self.assertAllEqual(results[1],
                                [[1, 1, 1], [1, 1, 1], [1, 1, 1]])

            # Two batches with 5 elements each; both rows have fewer than 3
            # True values. They must be padded with index 0 and mask 0,
            # even though the input shape is only known at run time.
            results = sess.run(
                [indices, mask],
                feed_dict={
                    condition: [[True, False, False, False, False],
                                [False, True, False, True, False]]
                })
            self.assertAllEqual(results[0], [[0, 0, 0], [1, 3, 0]])
            self.assertAllEqual(results[1], [[1, 0, 0], [1, 1, 0]])
def test_padded_where_with_padding(self):
    """Basic padding test for padded_where.

    When the condition has fewer True values than the requested size, the
    output is padded at the end with index 0 and mask value 0.
    """
    # Case where True condition values are at the left (beginning) of the
    # tensor.
    condition = tf.constant([True] * 4 + [False] * 4)
    indices, mask = tensor_utils.padded_where(condition, 6)
    self.assertAllEqual(indices, [0, 1, 2, 3, 0, 0])
    self.assertAllEqual(mask, [1, 1, 1, 1, 0, 0])

    # Case where True condition values are at the right (end) of the tensor.
    condition = tf.constant([False] * 4 + [True] * 4)
    indices, mask = tensor_utils.padded_where(condition, 6)
    self.assertAllEqual(indices, [4, 5, 6, 7, 0, 0])
    self.assertAllEqual(mask, [1, 1, 1, 1, 0, 0])

    # Case where True condition values are non-contiguous.
    condition = tf.constant(
        [False, True, False, False, True, False, True, True])
    indices, mask = tensor_utils.padded_where(condition, 6)
    self.assertAllEqual(indices, [1, 4, 6, 7, 0, 0])
    self.assertAllEqual(mask, [1, 1, 1, 1, 0, 0])

    # Case where no condition value is True: the output is entirely padding.
    condition = tf.constant([False] * 8)
    indices, mask = tensor_utils.padded_where(condition, 6)
    self.assertAllEqual(indices, [0, 0, 0, 0, 0, 0])
    self.assertAllEqual(mask, [0, 0, 0, 0, 0, 0])

    # Case with a batch dimension in the condition; each row is padded
    # independently.
    condition = tf.constant([[False, True, False, True, True, True],
                             [True, False, True, False, False, True],
                             [True, False, False, False, True, False]])
    indices, mask = tensor_utils.padded_where(condition, 6)
    self.assertAllEqual(indices, [
        [1, 3, 4, 5, 0, 0],
        [0, 2, 5, 0, 0, 0],
        [0, 4, 0, 0, 0, 0],
    ])
    self.assertAllEqual(mask, [
        [1, 1, 1, 1, 0, 0],
        [1, 1, 1, 0, 0, 0],
        [1, 1, 0, 0, 0, 0],
    ])
def test_padded_where_truncation(self):
    """Basic truncation test for padded_where.

    When the condition has at least as many True values as the requested
    size, only the first (lowest) matching indices are kept. In that case
    the mask is all ones.
    """
    # Case where True condition values are at the left (beginning) of the
    # tensor; 4 True values are truncated to 3.
    condition = tf.constant([True] * 4 + [False] * 4)
    indices, mask = tensor_utils.padded_where(condition, 3)
    self.assertAllEqual(indices, [0, 1, 2])
    self.assertAllEqual(mask, [1, 1, 1])

    # Case where True condition values are at the right (end) of the tensor.
    condition = tf.constant([False] * 4 + [True] * 4)
    indices, mask = tensor_utils.padded_where(condition, 3)
    self.assertAllEqual(indices, [4, 5, 6])
    self.assertAllEqual(mask, [1, 1, 1])

    # Case where True condition values are non-contiguous.
    condition = tf.constant(
        [False, True, False, False, True, False, True, True])
    indices, mask = tensor_utils.padded_where(condition, 3)
    self.assertAllEqual(indices, [1, 4, 6])
    self.assertAllEqual(mask, [1, 1, 1])

    # Boundary case: exactly as many True values as the requested size, so
    # nothing is truncated and nothing is padded.
    condition = tf.constant([True, False, True, False, True])
    indices, mask = tensor_utils.padded_where(condition, 3)
    self.assertAllEqual(indices, [0, 2, 4])
    self.assertAllEqual(mask, [1, 1, 1])

    # Case with a batch dimension in the condition; each row is truncated
    # independently.
    condition = tf.constant([[False, True, False, True, True, True],
                             [True, False, True, False, False, True],
                             [True, True, False, False, True, False]])
    indices, mask = tensor_utils.padded_where(condition, 3)
    self.assertAllEqual(indices, [
        [1, 3, 4],
        [0, 2, 5],
        [0, 1, 4],
    ])
    self.assertAllEqual(mask, [
        [1, 1, 1],
        [1, 1, 1],
        [1, 1, 1],
    ])