Example #1
  def _make_key_func(self, key_func, input_dataset):
    """Make wrapping Defun for key_func."""

    @function.Defun(
        *nest.flatten(sparse.unwrap_sparse_types(input_dataset.output_types)))
    def tf_key_func(*args):
      """A wrapper for Defun that facilitates shape inference."""
      # Pass in shape information from the input_dataset.
      for arg, shape in zip(args, nest.flatten(input_dataset.output_shapes)):
        arg.set_shape(shape)

      nested_args = nest.pack_sequence_as(input_dataset.output_types, args)
      nested_args = sparse.deserialize_sparse_tensors(
          nested_args, input_dataset.output_types)
      # pylint: disable=protected-access
      if dataset_ops._should_unpack_args(nested_args):
        ret = key_func(*nested_args)
      # pylint: enable=protected-access
      else:
        ret = key_func(nested_args)
      ret = ops.convert_to_tensor(ret, dtype=dtypes.int64)
      if ret.dtype != dtypes.int64:
        raise ValueError("`key_func` must return a single tf.int64 tensor.")
      return ret

    self._key_func = tf_key_func
    self._key_func.add_to_graph(ops.get_default_graph())
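
A minimal usage sketch for context, assuming the `tf.contrib.data.group_by_window` transformation of this era (which `_make_key_func` serves); the parity-based bucketing below is purely illustrative:

```python
import tensorflow as tf

# Bucket a range dataset by parity and emit batches of 10 per bucket.
# The lambda passed as key_func is what _make_key_func above wraps in a
# Defun; it must map one input element to a scalar tf.int64 key.
dataset = tf.data.Dataset.range(100)
dataset = dataset.apply(
    tf.contrib.data.group_by_window(
        key_func=lambda x: x % 2,  # already tf.int64 for a range dataset
        reduce_func=lambda key, window: window.batch(10),
        window_size=10))
```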
Example #2
    def _make_key_func(self, key_func, input_dataset):
        """Make wrapping Defun for key_func."""
        @function.Defun(*nest.flatten(
            sparse.unwrap_sparse_types(input_dataset.output_types)))
        def tf_key_func(*args):
            """A wrapper for Defun that facilitates shape inference."""
            # Pass in shape information from the input_dataset.
            for arg, shape in zip(args,
                                  nest.flatten(input_dataset.output_shapes)):
                arg.set_shape(shape)

            nested_args = nest.pack_sequence_as(input_dataset.output_types,
                                                args)
            nested_args = sparse.deserialize_sparse_tensors(
                nested_args, input_dataset.output_types)
            # pylint: disable=protected-access
            if dataset_ops._should_unpack_args(nested_args):
                ret = key_func(*nested_args)
            # pylint: enable=protected-access
            else:
                ret = key_func(nested_args)
            ret = ops.convert_to_tensor(ret, dtype=dtypes.int64)
            if ret.dtype != dtypes.int64:
                raise ValueError(
                    "`key_func` must return a single tf.int64 tensor.")
            return ret

        self._key_func = tf_key_func
        self._key_func.add_to_graph(ops.get_default_graph())
Example #3
    def from_string_handle(string_handle, output_types, output_shapes=None):
        """Creates a new, uninitialized `Iterator` based on the given handle.

    This method allows you to define a "feedable" iterator where you can choose
    between concrete iterators by feeding a value in a @{tf.Session.run} call.
    In that case, `string_handle` would be a @{tf.placeholder}, and you would feed
    it with the value of @{tf.data.Iterator.string_handle} in each step.

    For example, if you had two iterators that marked the current position in
    a training dataset and a test dataset, you could choose which to use in
    each step as follows:

    ```python
    train_iterator = tf.data.Dataset(...).make_one_shot_iterator()
    train_iterator_handle = sess.run(train_iterator.string_handle())

    test_iterator = tf.data.Dataset(...).make_one_shot_iterator()
    test_iterator_handle = sess.run(test_iterator.string_handle())

    handle = tf.placeholder(tf.string, shape=[])
    iterator = tf.data.Iterator.from_string_handle(
        handle, train_iterator.output_types)

    next_element = iterator.get_next()
    loss = f(next_element)

    train_loss = sess.run(loss, feed_dict={handle: train_iterator_handle})
    test_loss = sess.run(loss, feed_dict={handle: test_iterator_handle})
    ```

    Args:
      string_handle: A scalar `tf.Tensor` of type `tf.string` that evaluates
        to a handle produced by the `Iterator.string_handle()` method.
      output_types: A nested structure of `tf.DType` (or `tf.data.SparseType`)
        objects corresponding to each `tf.Tensor` (or `tf.SparseTensor`)
        component of an element of this dataset.
      output_shapes: (Optional.) A nested structure of `tf.TensorShape` objects
        corresponding to each component of an element of this dataset. If
        omitted, each component will have an unconstrained shape.

    Returns:
      An `Iterator`.
    """
        output_types = nest.map_structure(dtypes.as_dtype, output_types)
        if output_shapes is None:
            output_shapes = nest.map_structure(
                lambda _: tensor_shape.TensorShape(None), output_types)
        else:
            output_shapes = nest.map_structure_up_to(output_types,
                                                     tensor_shape.as_shape,
                                                     output_shapes)
        nest.assert_same_structure(output_types, output_shapes)
        string_handle = ops.convert_to_tensor(string_handle,
                                              dtype=dtypes.string)
        iterator_resource = gen_dataset_ops.iterator_from_string_handle(
            string_handle,
            output_types=nest.flatten(
                sparse.unwrap_sparse_types(output_types)),
            output_shapes=nest.flatten(output_shapes))
        return Iterator(iterator_resource, None, output_types, output_shapes)
Example #4
  def from_string_handle(string_handle, output_types, output_shapes=None):
    """Creates a new, uninitialized `Iterator` based on the given handle.

    This method allows you to define a "feedable" iterator where you can choose
    between concrete iterators by feeding a value in a @{tf.Session.run} call.
    In that case, `string_handle` would be a @{tf.placeholder}, and you would feed
    it with the value of @{tf.data.Iterator.string_handle} in each step.

    For example, if you had two iterators that marked the current position in
    a training dataset and a test dataset, you could choose which to use in
    each step as follows:

    ```python
    train_iterator = tf.data.Dataset(...).make_one_shot_iterator()
    train_iterator_handle = sess.run(train_iterator.string_handle())

    test_iterator = tf.data.Dataset(...).make_one_shot_iterator()
    test_iterator_handle = sess.run(test_iterator.string_handle())

    handle = tf.placeholder(tf.string, shape=[])
    iterator = tf.data.Iterator.from_string_handle(
        handle, train_iterator.output_types)

    next_element = iterator.get_next()
    loss = f(next_element)

    train_loss = sess.run(loss, feed_dict={handle: train_iterator_handle})
    test_loss = sess.run(loss, feed_dict={handle: test_iterator_handle})
    ```

    Args:
      string_handle: A scalar `tf.Tensor` of type `tf.string` that evaluates
        to a handle produced by the `Iterator.string_handle()` method.
      output_types: A nested structure of `tf.DType` (or `tf.data.SparseType`)
        objects corresponding to each `tf.Tensor` (or `tf.SparseTensor`)
        component of an element of this dataset.
      output_shapes: (Optional.) A nested structure of `tf.TensorShape` objects
        corresponding to each component of an element of this dataset. If
        omitted, each component will have an unconstrained shape.

    Returns:
      An `Iterator`.
    """
    output_types = nest.map_structure(dtypes.as_dtype, output_types)
    if output_shapes is None:
      output_shapes = nest.map_structure(
          lambda _: tensor_shape.TensorShape(None), output_types)
    else:
      output_shapes = nest.map_structure_up_to(
          output_types, tensor_shape.as_shape, output_shapes)
    nest.assert_same_structure(output_types, output_shapes)
    string_handle = ops.convert_to_tensor(string_handle, dtype=dtypes.string)
    iterator_resource = gen_dataset_ops.iterator_from_string_handle(
        string_handle,
        output_types=nest.flatten(sparse.unwrap_sparse_types(output_types)),
        output_shapes=nest.flatten(output_shapes))
    return Iterator(iterator_resource, None, output_types, output_shapes)
Example #5
 def _as_variant_tensor(self):
   input_t = self._input_dataset._as_variant_tensor()  # pylint: disable=protected-access
   return gen_dataset_ops.scan_dataset(
       input_t,
       nest.flatten(self._initial_state),
       self._scan_func.captured_inputs,
       f=self._scan_func,
       output_types=nest.flatten(
           sparse.unwrap_sparse_types(self.output_types)),
       output_shapes=nest.flatten(self.output_shapes))
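
For context, a hedged usage sketch of the transformation this lowering backs, assuming the `tf.contrib.data.scan(initial_state, scan_func)` API of this era:

```python
import tensorflow as tf

# scan threads a state through the dataset; scan_func returns a
# (new_state, output_element) pair, matching the contract enforced in
# _ScanDataset (Example #19 below).
# Running sum: for input 0, 1, 2, 3, 4 this emits 0, 0, 1, 3, 6.
dataset = tf.data.Dataset.range(5).apply(
    tf.contrib.data.scan(
        initial_state=tf.constant(0, dtype=tf.int64),
        scan_func=lambda state, x: (state + x, state)))
```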
Example #6
 def _as_variant_tensor(self):
     return gen_dataset_ops.parallel_interleave_dataset(
         self._input_dataset._as_variant_tensor(),  # pylint: disable=protected-access
         self._map_func.captured_inputs,
         self._cycle_length,
         self._block_length,
         self._sloppy,
         f=self._map_func,
         output_types=nest.flatten(
             sparse.unwrap_sparse_types(self.output_types)),
         output_shapes=nest.flatten(self.output_shapes))
Example #7
 def _as_variant_tensor(self):
   return gen_dataset_ops.parallel_interleave_dataset(
       self._input_dataset._as_variant_tensor(),  # pylint: disable=protected-access
       self._map_func.captured_inputs,
       self._cycle_length,
       self._block_length,
       self._sloppy,
       f=self._map_func,
       output_types=nest.flatten(
           sparse.unwrap_sparse_types(self.output_types)),
       output_shapes=nest.flatten(self.output_shapes))
Example #8
 def _as_variant_tensor(self):
     return gen_dataset_ops.group_by_window_dataset(
         self._input_dataset._as_variant_tensor(),  # pylint: disable=protected-access
         self._key_func.captured_inputs,
         self._reduce_func.captured_inputs,
         self._window_size_func.captured_inputs,
         key_func=self._key_func,
         reduce_func=self._reduce_func,
         window_size_func=self._window_size_func,
         output_types=nest.flatten(
             sparse.unwrap_sparse_types(self.output_types)),
         output_shapes=nest.flatten(self.output_shapes))
Example #9
 def _as_variant_tensor(self):
   # pylint: disable=protected-access
   input_resource = self._input_dataset._as_variant_tensor()
   return gen_dataset_ops.map_and_batch_dataset(
       input_resource,
       self._map_func.captured_inputs,
       f=self._map_func,
       batch_size=self._batch_size,
       num_parallel_batches=self._num_parallel_batches,
       output_types=nest.flatten(
           sparse.unwrap_sparse_types(self.output_types)),
       output_shapes=nest.flatten(self.output_shapes))
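
A hedged usage sketch of the fused transformation behind this op, assuming the `tf.contrib.data.map_and_batch` API of this era:

```python
import tensorflow as tf

dataset = tf.data.Dataset.range(100)
# map_and_batch fuses Dataset.map and Dataset.batch into a single op,
# producing up to num_parallel_batches batches concurrently.
dataset = dataset.apply(
    tf.contrib.data.map_and_batch(
        map_func=lambda x: x * 2,
        batch_size=32,
        num_parallel_batches=2))
```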
Example #10
 def _as_variant_tensor(self):
   return gen_dataset_ops.group_by_window_dataset(
       self._input_dataset._as_variant_tensor(),  # pylint: disable=protected-access
       self._key_func.captured_inputs,
       self._reduce_func.captured_inputs,
       self._window_size_func.captured_inputs,
       key_func=self._key_func,
       reduce_func=self._reduce_func,
       window_size_func=self._window_size_func,
       output_types=nest.flatten(
           sparse.unwrap_sparse_types(self.output_types)),
       output_shapes=nest.flatten(self.output_shapes))
Example #11
    def get_next(self, name=None):
        """Returns a nested structure of `tf.Tensor`s containing the next element.

    Args:
      name: (Optional.) A name for the created operation.

    Returns:
      A nested structure of `tf.Tensor` objects.
    """
        return sparse.deserialize_sparse_tensors(
            nest.pack_sequence_as(
                self._output_types,
                gen_dataset_ops.iterator_get_next(
                    self._iterator_resource,
                    output_types=nest.flatten(
                        sparse.unwrap_sparse_types(self._output_types)),
                    output_shapes=nest.flatten(self._output_shapes),
                    name=name)), self._output_types)
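
Typical graph-mode usage of `get_next`, as a minimal sketch assuming TF 1.x sessions:

```python
import tensorflow as tf

dataset = tf.data.Dataset.range(3)
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()

with tf.Session() as sess:
  # Draw elements until the iterator is exhausted.
  while True:
    try:
      print(sess.run(next_element))
    except tf.errors.OutOfRangeError:
      break
```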
Example #12
    def __init__(self, input_dataset, map_func, cycle_length, block_length,
                 sloppy):
        """See `tf.contrib.data.parallel_interleave()` for details."""
        super(ParallelInterleaveDataset, self).__init__()
        self._input_dataset = input_dataset

        @function.Defun(*nest.flatten(
            sparse.unwrap_sparse_types(input_dataset.output_types)))
        def tf_map_func(*args):
            """A wrapper for Defun that facilitates shape inference."""
            # Pass in shape information from the input_dataset.
            for arg, shape in zip(args,
                                  nest.flatten(input_dataset.output_shapes)):
                arg.set_shape(shape)

            nested_args = nest.pack_sequence_as(input_dataset.output_types,
                                                args)
            nested_args = sparse.deserialize_sparse_tensors(
                nested_args, input_dataset.output_types)
            if dataset_ops._should_unpack_args(nested_args):  # pylint: disable=protected-access
                dataset = map_func(*nested_args)
            else:
                dataset = map_func(nested_args)

            if not isinstance(dataset, dataset_ops.Dataset):
                raise TypeError("`map_func` must return a `Dataset` object.")

            self._output_types = dataset.output_types
            self._output_shapes = dataset.output_shapes

            return dataset._as_variant_tensor()  # pylint: disable=protected-access

        self._map_func = tf_map_func
        self._map_func.add_to_graph(ops.get_default_graph())

        self._cycle_length = ops.convert_to_tensor(cycle_length,
                                                   dtype=dtypes.int64,
                                                   name="cycle_length")
        self._block_length = ops.convert_to_tensor(block_length,
                                                   dtype=dtypes.int64,
                                                   name="block_length")
        self._sloppy = ops.convert_to_tensor(sloppy,
                                             dtype=dtypes.bool,
                                             name="sloppy")
Example #13
 def testUnwrapSparseTypes(self):
   d = dtypes.string
   t = sparse.SparseType(dtypes.int32)
   test_cases = (
       ((), ()),
       (t, d),
       (d, d),
       ((t), (d)),
       ((d), (d)),
       ((t, ()), (d, ())),
       (((), t), ((), d)),
       ((d, ()), (d, ())),
       (((), d), ((), d)),
       ((t, (), d), (d, (), d)),
       (((), t, ()), ((), d, ())),
       (((), d, ()), ((), d, ())),
   )
   for test_case in test_cases:
     self.assertEqual(sparse.unwrap_sparse_types(test_case[0]), test_case[1])
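
Based on the cases above, `unwrap_sparse_types` appears to replace each `SparseType` in a nested structure with `tf.string` (the dtype of a serialized `SparseTensor`) while passing dense dtypes through unchanged. A minimal sketch of that behavior, reusing the `nest`, `dtypes`, and `sparse` modules already imported by the test module:

```python
# Illustrative sketch only, not the real implementation: mirrors the
# behavior exercised by the test cases above.
def unwrap_sparse_types_sketch(types):
  # Replace each SparseType leaf with tf.string -- the dtype of a
  # serialized SparseTensor -- and leave dense dtypes as-is.
  return nest.map_structure(
      lambda t: dtypes.string if isinstance(t, sparse.SparseType) else t,
      types)
```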
Example #14
  def get_next(self, name=None):
    """Returns a nested structure of `tf.Tensor`s containing the next element.

    Args:
      name: (Optional.) A name for the created operation.

    Returns:
      A nested structure of `tf.Tensor` objects.
    """
    return sparse.deserialize_sparse_tensors(
        nest.pack_sequence_as(self._output_types,
                              gen_dataset_ops.iterator_get_next(
                                  self._iterator_resource,
                                  output_types=nest.flatten(
                                      sparse.unwrap_sparse_types(
                                          self._output_types)),
                                  output_shapes=nest.flatten(
                                      self._output_shapes),
                                  name=name)), self._output_types)
Example #15
 def testUnwrapSparseTypes(self):
     d = dtypes.string
     t = sparse.SparseType(dtypes.int32)
     test_cases = (
         ((), ()),
         (t, d),
         (d, d),
         ((t), (d)),
         ((d), (d)),
         ((t, ()), (d, ())),
         (((), t), ((), d)),
         ((d, ()), (d, ())),
         (((), d), ((), d)),
         ((t, (), d), (d, (), d)),
         (((), t, ()), ((), d, ())),
         (((), d, ()), ((), d, ())),
     )
     for test_case in test_cases:
         self.assertEqual(sparse.unwrap_sparse_types(test_case[0]),
                          test_case[1])
Example #16
  def __init__(self, input_dataset, map_func, cycle_length, block_length,
               sloppy):
    """See `tf.contrib.data.parallel_interleave()` for details."""
    super(ParallelInterleaveDataset, self).__init__()
    self._input_dataset = input_dataset

    @function.Defun(
        *nest.flatten(sparse.unwrap_sparse_types(input_dataset.output_types)))
    def tf_map_func(*args):
      """A wrapper for Defun that facilitates shape inference."""
      # Pass in shape information from the input_dataset.
      for arg, shape in zip(args, nest.flatten(input_dataset.output_shapes)):
        arg.set_shape(shape)

      nested_args = nest.pack_sequence_as(input_dataset.output_types, args)
      nested_args = sparse.deserialize_sparse_tensors(
          nested_args, input_dataset.output_types)
      if dataset_ops._should_unpack_args(nested_args):  # pylint: disable=protected-access
        dataset = map_func(*nested_args)
      else:
        dataset = map_func(nested_args)

      if not isinstance(dataset, dataset_ops.Dataset):
        raise TypeError("`map_func` must return a `Dataset` object.")

      self._output_types = dataset.output_types
      self._output_shapes = dataset.output_shapes

      return dataset._as_variant_tensor()  # pylint: disable=protected-access

    self._map_func = tf_map_func
    self._map_func.add_to_graph(ops.get_default_graph())

    self._cycle_length = ops.convert_to_tensor(
        cycle_length, dtype=dtypes.int64, name="cycle_length")
    self._block_length = ops.convert_to_tensor(
        block_length, dtype=dtypes.int64, name="block_length")
    self._sloppy = ops.convert_to_tensor(
        sloppy, dtype=dtypes.bool, name="sloppy")
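
A hedged usage sketch of the transformation this class backs, assuming the `tf.contrib.data.parallel_interleave` API of this era; note that `map_func` must return a `Dataset`, matching the `TypeError` check above:

```python
import tensorflow as tf

filenames = tf.data.Dataset.from_tensor_slices(
    ["file0.tfrecord", "file1.tfrecord"])  # illustrative paths
# map_func maps each element to a Dataset (here a TFRecordDataset per
# file); up to cycle_length of the returned datasets are read
# concurrently, block_length consecutive elements at a time.
dataset = filenames.apply(
    tf.contrib.data.parallel_interleave(
        lambda filename: tf.data.TFRecordDataset(filename),
        cycle_length=2,
        block_length=4,
        sloppy=False))
```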
Example #17
 def _as_variant_tensor(self):
     return gen_dataset_ops.ignore_errors_dataset(
         self._input_dataset._as_variant_tensor(),  # pylint: disable=protected-access
         output_shapes=nest.flatten(self.output_shapes),
         output_types=nest.flatten(
             sparse.unwrap_sparse_types(self.output_types)))
Example #18
 def _as_variant_tensor(self):
   return gen_dataset_ops.ignore_errors_dataset(
       self._input_dataset._as_variant_tensor(),  # pylint: disable=protected-access
       output_shapes=nest.flatten(self.output_shapes),
       output_types=nest.flatten(
           sparse.unwrap_sparse_types(self.output_types)))
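
A brief usage sketch, assuming the `tf.contrib.data.ignore_errors` transformation this op backs:

```python
import tensorflow as tf

dataset = tf.data.Dataset.from_tensor_slices([1., 2., 0., 4.])
# The map raises InvalidArgumentError on the 0. element (1/0 -> inf).
dataset = dataset.map(lambda x: tf.check_numerics(1. / x, "error"))
# Silently drop the failing element rather than aborting iteration.
dataset = dataset.apply(tf.contrib.data.ignore_errors())
```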
Example #19
  def __init__(self, input_dataset, initial_state, scan_func):
    """See `scan()` for details."""
    super(_ScanDataset, self).__init__()
    self._input_dataset = input_dataset

    with ops.name_scope("initial_state"):
      self._initial_state = nest.pack_sequence_as(initial_state, [
          ops.convert_to_tensor(t, name="component_%d" % i)
          for i, t in enumerate(nest.flatten(initial_state))
      ])

    # Compute initial values for the state shapes and types based on
    # the initial state. These will be refined by running
    # `tf_scan_func` one or more times below.
    # TODO(b/68937811): Allow the initial state to be a tf.SparseTensor.
    self._state_shapes = nest.pack_sequence_as(
        self._initial_state,
        [t.shape for t in nest.flatten(self._initial_state)])
    self._state_types = nest.pack_sequence_as(
        self._initial_state,
        [t.dtype for t in nest.flatten(self._initial_state)])

    # Will be populated by calling `tf_scan_func`.
    self._output_shapes = None
    self._output_types = None

    # Iteratively rerun the scan function until reaching a fixed point on
    # `self._state_shapes`.
    need_to_rerun = True
    while need_to_rerun:

      flat_state_shapes = nest.flatten(self._state_shapes)
      flat_state_types = nest.flatten(self._state_types)

      # Create a list in which `tf_scan_func` will store the new state shapes.
      flat_new_state_shapes = []

      @function.Defun(*(flat_state_types + nest.flatten(
          sparse.unwrap_sparse_types(input_dataset.output_types))))
      def tf_scan_func(*args):
        """A wrapper for Defun that facilitates shape inference."""
        # Pass in shape information from the state and input_dataset.
        for arg, shape in zip(
            args,
            flat_state_shapes + nest.flatten(input_dataset.output_shapes)):
          arg.set_shape(shape)

        pivot = len(flat_state_shapes)
        old_state = nest.pack_sequence_as(self._initial_state, args[:pivot])
        input_value = nest.pack_sequence_as(input_dataset.output_types,
                                            args[pivot:])

        ret = scan_func(old_state, input_value)
        if not isinstance(ret, collections.Sequence) or len(ret) != 2:
          raise TypeError("The scan function must return a pair comprising the "
                          "new state and the output value.")
        new_state, output_value = ret

        flat_new_state = [
            ops.convert_to_tensor(t) for t in nest.flatten(new_state)
        ]
        flat_output_value = [
            ops.convert_to_tensor(t) for t in nest.flatten(output_value)
        ]

        # Extract shape information from the returned values.
        flat_new_state_shapes.extend([t.shape for t in flat_new_state])
        self._output_shapes = nest.pack_sequence_as(
            output_value, [t.shape for t in flat_output_value])

        # Extract and validate type information from the returned values.
        for t, dtype in zip(flat_new_state, flat_state_types):
          if t.dtype != dtype:
            raise TypeError(
                "The element types for the new state must match the initial "
                "state. Expected %s; got %s." %
                (self._state_types, nest.pack_sequence_as(
                    self._state_types, [t.dtype for t in flat_new_state])))
        self._output_types = nest.pack_sequence_as(
            output_value, [t.dtype for t in flat_output_value])

        return flat_new_state + flat_output_value

      # Use the private method that will execute `tf_scan_func` but delay
      # adding it to the graph in case we need to rerun the function.
      tf_scan_func._create_definition_if_needed()  # pylint: disable=protected-access

      weakened_state_shapes = [
          original.most_specific_compatible_shape(new)
          for original, new in zip(flat_state_shapes, flat_new_state_shapes)
      ]

      need_to_rerun = False
      for original_shape, weakened_shape in zip(flat_state_shapes,
                                                weakened_state_shapes):
        if original_shape.ndims is not None and (
            weakened_shape.ndims is None or
            original_shape.as_list() != weakened_shape.as_list()):
          need_to_rerun = True
          break

      if need_to_rerun:
        # NOTE(mrry): `self._output_shapes` will be overwritten when we rerun
        # `tf_scan_func`.
        self._state_shapes = nest.pack_sequence_as(self._state_shapes,
                                                   weakened_state_shapes)

    self._scan_func = tf_scan_func
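
To see the shape "weakening" that drives the fixed-point loop above, here is a small self-contained illustration of the same `most_specific_compatible_shape` call:

```python
from tensorflow.python.framework import tensor_shape

original = tensor_shape.TensorShape([1])  # current state-shape guess
new = tensor_shape.TensorShape([2])       # shape returned by scan_func

weakened = original.most_specific_compatible_shape(new)
print(weakened)  # (?,) -- the mismatched dimension relaxes to unknown,
                 # which triggers one more trace of tf_scan_func.
```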
Example #20
    def from_structure(output_types, output_shapes=None, shared_name=None):
        """Creates a new, uninitialized `Iterator` with the given structure.

    This iterator-constructing method can be used to create an iterator that
    is reusable with many different datasets.

    The returned iterator is not bound to a particular dataset, and it has
    no `initializer`. To initialize the iterator, run the operation returned by
    `Iterator.make_initializer(dataset)`.

    The following is an example:

    ```python
    iterator = Iterator.from_structure(tf.int64, tf.TensorShape([]))

    dataset_range = Dataset.range(10)
    range_initializer = iterator.make_initializer(dataset_range)

    dataset_evens = dataset_range.filter(lambda x: x % 2 == 0)
    evens_initializer = iterator.make_initializer(dataset_evens)

    # Define a model based on the iterator; in this example, the model_fn
    # is expected to take scalar tf.int64 Tensors as input (see
    # the definition of 'iterator' above).
    prediction, loss = model_fn(iterator.get_next())

    # Train for `num_epochs`, where for each epoch, we first iterate over
    # dataset_range, and then iterate over dataset_evens.
    for _ in range(num_epochs):
      # Initialize the iterator to `dataset_range`
      sess.run(range_initializer)
      while True:
        try:
          pred, loss_val = sess.run([prediction, loss])
        except tf.errors.OutOfRangeError:
          break

      # Initialize the iterator to `dataset_evens`
      sess.run(evens_initializer)
      while True:
        try:
          pred, loss_val = sess.run([prediction, loss])
        except tf.errors.OutOfRangeError:
          break
    ```

    Args:
      output_types: A nested structure of `tf.DType` (or `tf.data.SparseType`)
        objects corresponding to each `tf.Tensor` (or `tf.SparseTensor`)
        component of an element of this dataset.
      output_shapes: (Optional.) A nested structure of `tf.TensorShape` objects
        corresponding to each component of an element of this dataset. If
        omitted, each component will have an unconstrained shape.
      shared_name: (Optional.) If non-empty, this iterator will be shared under
        the given name across multiple sessions that share the same devices
        (e.g. when using a remote server).

    Returns:
      An `Iterator`.

    Raises:
      TypeError: If the structures of `output_shapes` and `output_types` are
        not the same.
    """
        output_types = nest.map_structure(dtypes.as_dtype, output_types)
        if output_shapes is None:
            output_shapes = nest.map_structure(
                lambda _: tensor_shape.TensorShape(None), output_types)
        else:
            output_shapes = nest.map_structure_up_to(output_types,
                                                     tensor_shape.as_shape,
                                                     output_shapes)
        nest.assert_same_structure(output_types, output_shapes)
        if shared_name is None:
            shared_name = ""
        iterator_resource = gen_dataset_ops.iterator(
            container="",
            shared_name=shared_name,
            output_types=nest.flatten(
                sparse.unwrap_sparse_types(output_types)),
            output_shapes=nest.flatten(output_shapes))
        return Iterator(iterator_resource, None, output_types, output_shapes)
Example #21
  def from_structure(output_types, output_shapes=None, shared_name=None):
    """Creates a new, uninitialized `Iterator` with the given structure.

    This iterator-constructing method can be used to create an iterator that
    is reusable with many different datasets.

    The returned iterator is not bound to a particular dataset, and it has
    no `initializer`. To initialize the iterator, run the operation returned by
    `Iterator.make_initializer(dataset)`.

    The following is an example:

    ```python
    iterator = Iterator.from_structure(tf.int64, tf.TensorShape([]))

    dataset_range = Dataset.range(10)
    range_initializer = iterator.make_initializer(dataset_range)

    dataset_evens = dataset_range.filter(lambda x: x % 2 == 0)
    evens_initializer = iterator.make_initializer(dataset_evens)

    # Define a model based on the iterator; in this example, the model_fn
    # is expected to take scalar tf.int64 Tensors as input (see
    # the definition of 'iterator' above).
    prediction, loss = model_fn(iterator.get_next())

    # Train for `num_epochs`, where for each epoch, we first iterate over
    # dataset_range, and then iterate over dataset_evens.
    for _ in range(num_epochs):
      # Initialize the iterator to `dataset_range`
      sess.run(range_initializer)
      while True:
        try:
          pred, loss_val = sess.run([prediction, loss])
        except tf.errors.OutOfRangeError:
          break

      # Initialize the iterator to `dataset_evens`
      sess.run(evens_initializer)
      while True:
        try:
          pred, loss_val = sess.run([prediction, loss])
        except tf.errors.OutOfRangeError:
          break
    ```

    Args:
      output_types: A nested structure of `tf.DType` (or `tf.data.SparseType`)
        objects corresponding to each `tf.Tensor` (or `tf.SparseTensor`)
        component of an element of this dataset.
      output_shapes: (Optional.) A nested structure of `tf.TensorShape` objects
        corresponding to each component of an element of this dataset. If
        omitted, each component will have an unconstrained shape.
      shared_name: (Optional.) If non-empty, this iterator will be shared under
        the given name across multiple sessions that share the same devices
        (e.g. when using a remote server).

    Returns:
      An `Iterator`.

    Raises:
      TypeError: If the structures of `output_shapes` and `output_types` are
        not the same.
    """
    output_types = nest.map_structure(dtypes.as_dtype, output_types)
    if output_shapes is None:
      output_shapes = nest.map_structure(
          lambda _: tensor_shape.TensorShape(None), output_types)
    else:
      output_shapes = nest.map_structure_up_to(
          output_types, tensor_shape.as_shape, output_shapes)
    nest.assert_same_structure(output_types, output_shapes)
    if shared_name is None:
      shared_name = ""
    iterator_resource = gen_dataset_ops.iterator(
        container="",
        shared_name=shared_name,
        output_types=nest.flatten(sparse.unwrap_sparse_types(output_types)),
        output_shapes=nest.flatten(output_shapes))
    return Iterator(iterator_resource, None, output_types, output_shapes)