Ejemplo n.º 1
0
    def testConvertToDTypeRaises(self, tensor_or_dtype, dtype, dtype_hint):
        """Checks that `dtype_util.convert_to_dtype` rejects incompatible dtypes.

        The conversion is attempted twice: once on `tensor_or_dtype` exactly as
        given, and once on a concrete example value built from it. Both calls
        must raise `TypeError` whose message matches 'Found incompatible dtypes'.

        Args:
          tensor_or_dtype: A NumPy scalar type (or `np.dtype`), a `tf.DType`, or
            any other value, which is used unchanged as the concrete example.
          dtype: Requested dtype forwarded to `dtype_util.convert_to_dtype`.
          dtype_hint: Dtype hint forwarded to `dtype_util.convert_to_dtype`.
        """
        # `np.issctype` was removed in NumPy 2.0. This reproduces its check:
        # a type or np.dtype instance that NumPy maps to a scalar type other
        # than `np.object_`. Python builtins like `float` therefore still count.
        try:
            is_numpy_scalar_type = (
                isinstance(tensor_or_dtype, (type, np.dtype)) and
                np.dtype(tensor_or_dtype).type is not np.object_)
        except (TypeError, ValueError):
            is_numpy_scalar_type = False

        if is_numpy_scalar_type:
            # Rank-0 NumPy array of the given scalar type.
            example_tensor = np.zeros([], tensor_or_dtype)
        elif isinstance(tensor_or_dtype, tf.DType):
            # Rank-0 TF tensor of the given dtype.
            example_tensor = tf.zeros([], tensor_or_dtype)
        else:
            # Already a concrete value; use it as-is.
            example_tensor = tensor_or_dtype

        with self.assertRaisesRegex(TypeError, 'Found incompatible dtypes'):
            dtype_util.convert_to_dtype(tensor_or_dtype, dtype, dtype_hint)
        with self.assertRaisesRegex(TypeError, 'Found incompatible dtypes'):
            dtype_util.convert_to_dtype(example_tensor, dtype, dtype_hint)
Ejemplo n.º 2
0
    def testConvertToDtype(self, tensor_or_dtype, dtype, dtype_hint):
        """Checks that `dtype_util.convert_to_dtype` agrees with TensorFlow.

        The reference dtype is whatever `tf.convert_to_tensor` produces for a
        concrete example value. `dtype_util.convert_to_dtype` must return that
        same dtype both for `tensor_or_dtype` as given and for the concrete
        example value.

        Args:
          tensor_or_dtype: A NumPy scalar type (or `np.dtype`), a `tf.DType`, or
            any other value, which is used unchanged as the concrete example.
          dtype: Requested dtype forwarded to both conversion functions.
          dtype_hint: Dtype hint forwarded to both conversion functions.
        """
        # `np.issctype` was removed in NumPy 2.0. This reproduces its check:
        # a type or np.dtype instance that NumPy maps to a scalar type other
        # than `np.object_`. Python builtins like `float` therefore still count.
        try:
            is_numpy_scalar_type = (
                isinstance(tensor_or_dtype, (type, np.dtype)) and
                np.dtype(tensor_or_dtype).type is not np.object_)
        except (TypeError, ValueError):
            is_numpy_scalar_type = False

        if is_numpy_scalar_type:
            # Rank-0 NumPy array of the given scalar type.
            example_tensor = np.zeros([], tensor_or_dtype)
        elif isinstance(tensor_or_dtype, tf.DType):
            # Rank-0 TF tensor of the given dtype.
            example_tensor = tf.zeros([], tensor_or_dtype)
        else:
            # Already a concrete value; use it as-is.
            example_tensor = tensor_or_dtype

        # Reference dtype inferred by TensorFlow for the concrete value. It is
        # computed once and shared by both assertions below.
        expected_dtype = tf.convert_to_tensor(
            example_tensor, dtype, dtype_hint).dtype

        # Try with the original argument.
        self.assertEqual(
            expected_dtype,
            dtype_util.convert_to_dtype(tensor_or_dtype, dtype, dtype_hint))
        # Try with a concrete value.
        self.assertEqual(
            expected_dtype,
            dtype_util.convert_to_dtype(example_tensor, dtype, dtype_hint))
Ejemplo n.º 3
0
 def test_ones_like(self):
     """`convert_to_dtype` reports float32 for `ps.ones_like` of a float32 input.

     The input is a float32 placeholder with a default value and no declared
     static shape.
     """
     default_value = tf.ones([2], dtype=tf.float32)
     # shape=None: no static shape is declared for the placeholder.
     x = tf1.placeholder_with_default(default_value, shape=None)
     result_dtype = dtype_util.convert_to_dtype(ps.ones_like(x))
     self.assertEqual(result_dtype, tf.float32)