def testUnbatchSingleElementTupleDataset(self):
  """Batching then unbatching a dataset of 1-tuples round-trips its elements.

  Three `(range(10),)` components are sliced together, batched by 2, and then
  unbatched. Both transformations must leave the output types unchanged, and
  the ten original elements must come back in order, followed by
  OutOfRangeError.
  """
  components = tuple((math_ops.range(10),) for _ in range(3))
  dataset = dataset_ops.Dataset.from_tensor_slices(components)
  want_types = ((dtypes.int32,),) * 3

  batched = dataset.batch(2)
  self.assertEqual(want_types, batched.output_types)

  unbatched = batched.apply(dataset_ops.unbatch())
  self.assertEqual(want_types, unbatched.output_types)

  next_element = unbatched.make_one_shot_iterator().get_next()
  with self.test_session() as sess:
    for value in range(10):
      self.assertEqual(((value,),) * 3, sess.run(next_element))
    # The iterator must be exhausted once all ten elements are consumed.
    with self.assertRaises(errors.OutOfRangeError):
      sess.run(next_element)
def testUnbatchSingleElementTupleDataset(self):
  """Checks that unbatch() inverts batch(2) for three int32 1-tuple components.

  Output types are checked after batching and after unbatching. Element
  values are then read back one by one until the iterator reports
  OutOfRangeError.
  """
  slices = []
  for _ in range(3):
    slices.append((math_ops.range(10),))
  dataset = dataset_ops.Dataset.from_tensor_slices(tuple(slices))
  expected_types = ((dtypes.int32,),) * 3

  dataset = dataset.batch(2)
  self.assertEqual(expected_types, dataset.output_types)
  dataset = dataset.apply(dataset_ops.unbatch())
  self.assertEqual(expected_types, dataset.output_types)

  get_next = dataset.make_one_shot_iterator().get_next()
  with self.test_session() as sess:
    for idx in range(10):
      expected_element = ((idx,),) * 3
      self.assertEqual(expected_element, sess.run(get_next))
    # No eleventh element: reading past the end must raise.
    with self.assertRaises(errors.OutOfRangeError):
      sess.run(get_next)
def testUnbatchMultiElementTupleDataset(self):
  """Batching then unbatching a dataset of (int32, string) pairs round-trips.

  Component i (for i in 0..2) is the pair (range(10*i, 10*i + 10), ["hi"]*10).
  The dataset is batched by 2 and then unbatched. Output types must be
  unchanged by both steps, and element j must be
  ((j, b"hi"), (10 + j, b"hi"), (20 + j, b"hi")) for j in 0..9, followed by
  OutOfRangeError.
  """
  data = tuple([(math_ops.range(10 * i, 10 * i + 10),
                 array_ops.fill([10], "hi")) for i in range(3)])
  data = dataset_ops.Dataset.from_tensor_slices(data)
  expected_types = ((dtypes.int32, dtypes.string),) * 3

  data = data.batch(2)
  # output_types is a nested tuple of DTypes; compare it structurally with
  # assertEqual, as the single-element test does. Using assertAllEqual would
  # route it through numpy object arrays.
  self.assertEqual(expected_types, data.output_types)
  data = data.apply(dataset_ops.unbatch())
  self.assertEqual(expected_types, data.output_types)

  iterator = data.make_one_shot_iterator()
  op = iterator.get_next()
  with self.test_session() as sess:
    for i in range(10):
      self.assertEqual(((i, b"hi"), (10 + i, b"hi"), (20 + i, b"hi")),
                       sess.run(op))
    # All ten elements consumed; the next read must signal end of input.
    with self.assertRaises(errors.OutOfRangeError):
      sess.run(op)
def testUnbatchMultiElementTupleDataset(self):
  """Checks that unbatch() inverts batch(2) for three (int32, string) pairs.

  The int32 part of component i holds the values 10*i .. 10*i + 9, and the
  string part is ten copies of "hi". After batching and unbatching, the
  output types are unchanged. The elements are read back in their original
  order until OutOfRangeError.
  """
  data = tuple([(math_ops.range(10 * i, 10 * i + 10),
                 array_ops.fill([10], "hi")) for i in range(3)])
  data = dataset_ops.Dataset.from_tensor_slices(data)
  expected_types = ((dtypes.int32, dtypes.string),) * 3

  data = data.batch(2)
  # Exact structural comparison of the nested DType tuple, consistent with
  # testUnbatchSingleElementTupleDataset (assertAllEqual goes via numpy).
  self.assertEqual(expected_types, data.output_types)
  data = data.apply(dataset_ops.unbatch())
  self.assertEqual(expected_types, data.output_types)

  iterator = data.make_one_shot_iterator()
  op = iterator.get_next()
  with self.test_session() as sess:
    for i in range(10):
      self.assertEqual(((i, b"hi"), (10 + i, b"hi"), (20 + i, b"hi")),
                       sess.run(op))
    # Reading past the tenth element must raise.
    with self.assertRaises(errors.OutOfRangeError):
      sess.run(op)