def testConvertValue(self, value, value_type, expected=None):
  """Checks that `_convert_value` maps `value` to the expected result.

  Args:
    value: The value to convert, or a zero-argument callable that builds it
      (used when construction must be deferred, e.g. the value contains a
      tensor).
    value_type: The type handed to `extension_type_field._convert_value`.
    expected: The result the conversion should produce. When None, the
      (constructed) input value itself is used as the expectation.
  """
  # Deferred construction (contains tensor): build the value now.
  actual_value = value() if callable(value) else value
  want = actual_value if expected is None else expected

  got = extension_type_field._convert_value(actual_value, value_type, ('x',))

  # Tensor-like results need the element-wise comparison helper; everything
  # else is compared with plain equality.
  tensor_like = (ops.Tensor, ragged_tensor.RaggedTensor)
  check = (
      self.assertAllEqual if isinstance(got, tensor_like)
      else self.assertEqual)
  check(got, want)
def testConvertValueError(self, value, value_type, error):
  """Checks that `_convert_value` rejects `value` with a matching TypeError.

  Args:
    value: The value to convert, or a zero-argument callable that builds it
      (used when construction must be deferred, e.g. the value contains a
      tensor).
    value_type: The type handed to `extension_type_field._convert_value`.
    error: Regular expression the TypeError message must match.
  """
  # Deferred construction (contains tensor): build the value before entering
  # the assertion context so only the conversion itself is under test.
  bad_value = value() if callable(value) else value

  with self.assertRaisesRegex(TypeError, error):
    extension_type_field._convert_value(bad_value, value_type, ('x',))