def testTensorEquals(self):
    """Check tensor_equals / tensor_not_equals on RaggedTensors.

    When TF2 is enabled and execution is eager outside functions, the
    comparisons are value-based (element-wise, broadcasting a scalar).
    Otherwise they are identity-based.
    """
    lhs = ragged_factory_ops.constant([[1, 2], [3]])
    same_shape = ragged_factory_ops.constant([[4, 5], [3]])
    scalar = 2
    incompatible = ragged_factory_ops.constant([[4, 5], [3, 2, 1]])

    value_semantics = (tf2.enabled()
                       and ops.executing_eagerly_outside_functions())
    if not value_semantics:
        # Identity-based equality: compare object identity, not contents.
        self.assertAllEqual(math_ops.tensor_equals(lhs, lhs), True)
        self.assertAllEqual(math_ops.tensor_equals(lhs, same_shape), False)
        self.assertAllEqual(
            math_ops.tensor_not_equals(lhs, same_shape), True)
        return

    # Value-based, element-wise comparisons; the scalar rhs is broadcast.
    elementwise_cases = (
        (math_ops.tensor_equals, same_shape, [[False, False], [True]]),
        (math_ops.tensor_not_equals, same_shape, [[True, True], [False]]),
        (math_ops.tensor_equals, scalar, [[False, True], [False]]),
        (math_ops.tensor_not_equals, scalar, [[True, False], [True]]),
    )
    for compare, rhs, expected in elementwise_cases:
        self.assertAllEqual(compare(lhs, rhs), expected)

    # For shapes that are not broadcast-compatible, the result compares
    # equal to a plain False / True rather than to an element-wise tensor.
    self.assertEqual(math_ops.tensor_equals(lhs, incompatible), False)
    self.assertEqual(math_ops.tensor_not_equals(lhs, incompatible), True)
Example #2
0
 def testEqualityNone(self):
     """Check that a tensor is never reported equal to None.

     Covers unittest's assertNotEqual with the tensor as either operand, and
     math_ops.tensor_equals / tensor_not_equals against None.
     """
     vec = constant_op.constant([1.0, 2.0, 0.0, 4.0], dtype=dtypes.float32)
     # Both operand orders, tensor first and then None first.
     for first, second in ((vec, None), (None, vec)):
         self.assertNotEqual(first, second)
     self.assertFalse(math_ops.tensor_equals(vec, None))
     self.assertTrue(math_ops.tensor_not_equals(vec, None))