def test_make_dataset_from_variant_tensor_constructs_dataset(self):
  """Round-trips `Dataset.range(5)` through a variant tensor and sums it.

  Checks that the returned object is a `tf.data.Dataset` and that reducing
  it with `+` from an int64 zero yields 0 + 1 + 2 + 3 + 4 == 10.
  """
  with tf.Graph().as_default():
    source = tf.data.Dataset.range(5)
    variant = tf.data.experimental.to_variant(source)
    dataset = tensorflow_utils.make_dataset_from_variant_tensor(
        variant, tf.int64)
    self.assertIsInstance(dataset, tf.data.Dataset)
    # The initial state must be int64 to match the dataset's element dtype.
    total = dataset.reduce(np.int64(0), lambda acc, elem: acc + elem)
    # Graph mode: evaluate the reduction tensor in a v1 session.
    with tf.compat.v1.Session() as session:
      self.assertEqual(session.run(total), 10)
def test_make_dataset_from_variant_tensor_fails_with_bad_type(self):
  """Expects a `TypeError` when the element type argument is a string ('a').

  The variant input is built before entering `assertRaises` so that only the
  `make_dataset_from_variant_tensor` call can satisfy the expectation; a
  `TypeError` from fixture construction would otherwise mask a regression.
  """
  with tf.Graph().as_default():
    variant = tf.data.experimental.to_variant(tf.data.Dataset.range(5))
    with self.assertRaises(TypeError):
      tensorflow_utils.make_dataset_from_variant_tensor(variant, 'a')
def test_make_dataset_from_variant_tensor_fails_with_bad_tensor(self):
  """Expects a `TypeError` when given a plain int32 constant, not a variant.

  The constant is created before entering `assertRaises` so that only the
  `make_dataset_from_variant_tensor` call can satisfy the expectation.
  """
  with tf.Graph().as_default():
    not_a_variant = tf.constant(10)
    with self.assertRaises(TypeError):
      tensorflow_utils.make_dataset_from_variant_tensor(
          not_a_variant, tf.int32)