def test_where_select_axis_0(self):
    """where() with a 1-D condition should choose entire rows of the operands.

    Row 0 is expected from the "true" operand and row 1 from the "false"
    operand, matching the [True, False] condition.
    """
    with self.test_session(graph=tf.Graph()):
        mask = tf.constant([True, False])
        on_true = tf.constant([[1, 2], [3, 4]])  # pyformat: disable
        on_false = tf.constant([[5, 6], [7, 8]])  # pyformat: disable
        expected = [[1, 2], [7, 8]]  # pyformat: disable

        selected = tensor_utils.where(mask, on_true, on_false)
        self.assertAllEqual(selected.eval(), expected)
def test_where_select_nontrivial(self):
    """where() with a 2-D condition over 3-D operands selects at an inner axis.

    Each [i][j] slice of the result is expected to be copied whole from the
    "true" operand where mask[i][j] is True, and from the "false" operand
    otherwise.
    """
    with self.test_session(graph=tf.Graph()):
        mask = tf.constant([[True, False], [False, True]])  # pyformat: disable
        on_true = tf.constant(
            [[[1, 1], [2, 2]], [[3, 3], [4, 4]]])  # pyformat: disable
        on_false = tf.constant(
            [[[5, 5], [6, 6]], [[7, 7], [8, 8]]])  # pyformat: disable
        expected = [[[1, 1], [6, 6]], [[7, 7], [4, 4]]]  # pyformat: disable

        selected = tensor_utils.where(mask, on_true, on_false)
        self.assertAllEqual(selected.eval(), expected)