def enumerate_dataset(start=0):
    """A transformation that enumerates the elements of a dataset.

    It is similar to Python's built-in `enumerate`: each element of the
    resulting dataset is an `(index, element)` pair, where `index` counts up
    from `start` and `element` keeps its original (possibly nested) structure.

    For example:

    ```python
    # NOTE: The following examples use `{ ... }` to represent the
    # contents of a dataset.
    a = { 1, 2, 3 }
    b = { (7, 8), (9, 10) }

    # Each input element is paired with its index; nested elements such as
    # the tuples in `b` are kept intact as the second component of the pair.
    a.apply(tf.contrib.data.enumerate_dataset(start=5)) == { (5, 1), (6, 2), (7, 3) }
    b.apply(tf.contrib.data.enumerate_dataset()) == { (0, (7, 8)), (1, (9, 10)) }
    ```

    Args:
      start: A `tf.int64` scalar `tf.Tensor` (or a Python integer, as with
        the default of 0), representing the start value for enumeration.

    Returns:
      A `Dataset` transformation function, which can be passed to
      `tf.data.Dataset.apply`.
    """
    # Thin wrapper: the actual transformation lives in `enumerate_ops`.
    return enumerate_ops.enumerate_dataset(start)
def testEnumerateDataset(self):
    """Checks the output signature and values of an enumerated dataset."""
    # Three parallel columns; slicing them yields two tuple elements.
    strings = ["a", "b"]
    ints = [1, 2]
    floats = [37.0, 38]
    components = (strings, ints, floats)
    first_index = constant_op.constant(20, dtype=dtypes.int64)

    source = dataset_ops.Dataset.from_tensor_slices(components)
    enumerated = source.apply(enumerate_ops.enumerate_dataset(first_index))

    # Position 0 is the scalar int64 index; position 1 holds the original
    # three scalar components.
    self.assertEqual(dtypes.int64, enumerated.output_types[0])
    self.assertEqual((), enumerated.output_shapes[0])
    scalar_shapes = [tensor_shape.TensorShape([])] * 3
    self.assertEqual(scalar_shapes, list(enumerated.output_shapes[1]))

    expected_elements = [
        (20, (b"a", 1, 37.0)),
        (21, (b"b", 2, 38.0)),
    ]
    self.assertDatasetProduces(enumerated, expected_elements)
def testEnumerateDataset(self):
    """Checks types, shapes and values produced by `enumerate_dataset` via an iterator."""
    components = (["a", "b"], [1, 2], [37.0, 38])
    start = constant_op.constant(20, dtype=dtypes.int64)

    iterator = (
        dataset_ops.Dataset.from_tensor_slices(components)
        .apply(enumerate_ops.enumerate_dataset(start))
        .make_initializable_iterator())
    init_op = iterator.initializer
    get_next = iterator.get_next()

    # The index is a scalar int64; each of the three components is a scalar.
    self.assertEqual(dtypes.int64, get_next[0].dtype)
    self.assertEqual((), get_next[0].shape)
    self.assertEqual([tensor_shape.TensorShape([])] * 3,
                     [t.shape for t in get_next[1]])

    # `cached_session` installs its session as the default, so every fetch
    # below goes through `self.evaluate` consistently.
    with self.cached_session():
        self.evaluate(init_op)
        self.assertEqual((20, (b"a", 1, 37.0)), self.evaluate(get_next))
        self.assertEqual((21, (b"b", 2, 38.0)), self.evaluate(get_next))
        # Only two elements were sliced, so a third fetch must be exhausted.
        with self.assertRaises(errors.OutOfRangeError):
            self.evaluate(get_next)