Ejemplo n.º 1
0
    def test_where_select_axis_0(self):
        """Checks that where() picks whole rows when selecting on axis 0.

        A rank-1 condition is paired with rank-2 branch tensors, so each
        boolean chooses an entire row from either the "true" or "false" input.
        """
        graph = tf.Graph()
        with self.test_session(graph=graph):
            cond = tf.constant([True, False])
            on_true = tf.constant([[1, 2], [3, 4]])  # pyformat: disable
            on_false = tf.constant([[5, 6], [7, 8]])  # pyformat: disable
            # Row 0 comes from on_true, row 1 from on_false.
            expected = [[1, 2], [7, 8]]  # pyformat: disable

            selected = tensor_utils.where(cond, on_true, on_false)
            self.assertAllEqual(selected.eval(), expected)
Ejemplo n.º 2
0
    def test_where_select_nontrivial(self):
        """Checks that where() selects along an intermediate axis.

        The rank-2 condition lines up with the two leading dimensions of the
        rank-3 branch tensors, so each boolean chooses one innermost pair from
        either the "true" or "false" input.
        """
        with self.test_session(graph=tf.Graph()):
            mask = tf.constant([[True, False], [False, True]])  # pyformat: disable
            picks_true = tf.constant(
                [[[1, 1], [2, 2]], [[3, 3], [4, 4]]])  # pyformat: disable
            picks_false = tf.constant(
                [[[5, 5], [6, 6]], [[7, 7], [8, 8]]])  # pyformat: disable
            # Diagonal entries come from picks_true, off-diagonal from picks_false.
            expected = [[[1, 1], [6, 6]],
                        [[7, 7], [4, 4]]]  # pyformat: disable

            selected = tensor_utils.where(mask, picks_true, picks_false)
            self.assertAllEqual(selected.eval(), expected)