Ejemplo n.º 1
0
def enumerate_dataset(start=0):
    """Build a transformation that numbers the elements of a dataset.

    The behaviour is similar to Python's built-in `enumerate`: every element
    of the input dataset is paired with a running counter that begins at
    `start`.

    For example:

    ```python
    # NOTE: `{ ... }` below denotes the contents of a dataset.
    a = { 1, 2, 3 }
    b = { (7, 8), (9, 10) }

    # Each element of the result is a `(counter, original_element)` pair.
    a.apply(tf.contrib.data.enumerate(start=5)) == { (5, 1), (6, 2), (7, 3) }
    b.apply(tf.contrib.data.enumerate()) == { (0, (7, 8)), (1, (9, 10)) }
    ```

    Args:
      start: A `tf.int64` scalar `tf.Tensor` giving the first counter value.
        Defaults to 0.

    Returns:
      A `Dataset` transformation function that can be passed to
      `tf.data.Dataset.apply`.
    """
    # All of the work is delegated to the `enumerate_ops` implementation.
    transformation = enumerate_ops.enumerate_dataset(start)
    return transformation
Ejemplo n.º 2
0
def enumerate_dataset(start=0):
  """Build a transformation that numbers the elements of a dataset.

  The behaviour is similar to Python's built-in `enumerate`: every element of
  the input dataset is paired with a running counter that begins at `start`.

  For example:

  ```python
  # NOTE: `{ ... }` below denotes the contents of a dataset.
  a = { 1, 2, 3 }
  b = { (7, 8), (9, 10) }

  # Each element of the result is a `(counter, original_element)` pair.
  a.apply(tf.contrib.data.enumerate(start=5)) == { (5, 1), (6, 2), (7, 3) }
  b.apply(tf.contrib.data.enumerate()) == { (0, (7, 8)), (1, (9, 10)) }
  ```

  Args:
    start: A `tf.int64` scalar `tf.Tensor` giving the first counter value.
      Defaults to 0.

  Returns:
    A `Dataset` transformation function that can be passed to
    `tf.data.Dataset.apply`.
  """
  # All of the work is delegated to the `enumerate_ops` implementation.
  transformation = enumerate_ops.enumerate_dataset(start)
  return transformation
  def testEnumerateDataset(self):
    """Enumerates a sliced dataset from 20 and checks types, shapes, values."""
    tensors = (["a", "b"], [1, 2], [37.0, 38])
    first_index = constant_op.constant(20, dtype=dtypes.int64)

    sliced = dataset_ops.Dataset.from_tensor_slices(tensors)
    enumerated = sliced.apply(enumerate_ops.enumerate_dataset(first_index))

    # Position 0 of every element is the counter: a scalar int64.
    self.assertEqual(dtypes.int64, enumerated.output_types[0])
    self.assertEqual((), enumerated.output_shapes[0])

    # Position 1 is the original element: three scalar components.
    scalar_shape = tensor_shape.TensorShape([])
    self.assertEqual([scalar_shape] * 3, list(enumerated.output_shapes[1]))

    expected_elements = [
        (20, (b"a", 1, 37.0)),
        (21, (b"b", 2, 38.0)),
    ]
    self.assertDatasetProduces(enumerated, expected_elements)
Ejemplo n.º 4
0
  def testEnumerateDataset(self):
    """Enumerates a sliced dataset from 20 through an initializable iterator."""
    tensors = (["a", "b"], [1, 2], [37.0, 38])
    first_index = constant_op.constant(20, dtype=dtypes.int64)

    sliced = dataset_ops.Dataset.from_tensor_slices(tensors)
    enumerated = sliced.apply(enumerate_ops.enumerate_dataset(first_index))
    iterator = enumerated.make_initializable_iterator()
    init_op = iterator.initializer
    next_element = iterator.get_next()

    # Position 0 of every element is the counter: a scalar int64.
    self.assertEqual(dtypes.int64, next_element[0].dtype)
    self.assertEqual((), next_element[0].shape)

    # Position 1 is the original element: three scalar components.
    scalar_shape = tensor_shape.TensorShape([])
    self.assertEqual([scalar_shape] * 3,
                     [component.shape for component in next_element[1]])

    with self.cached_session() as sess:
      self.evaluate(init_op)

      expected_rows = [
          (20, (b"a", 1, 37.0)),
          (21, (b"b", 2, 38.0)),
      ]
      for expected_row in expected_rows:
        self.assertEqual(expected_row, self.evaluate(next_element))

      # Both rows have been consumed, so another fetch must report exhaustion.
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(next_element)