示例#1
0
            def tf_reduce_func(*args):
                """Wraps `reduce_func` for Defun and restores static shapes.

                The flat `args` hold the state components followed by the input
                element components. Both are re-nested and passed to
                `reduce_func`; the new state is type-checked against the initial
                state and returned as a flat list with sparse tensors serialized.

                Raises:
                  TypeError: if a component dtype of the new state differs from
                    the corresponding dtype in `self._state_types`.
                """
                # Re-attach static shape information to every flat argument.
                state_dense_shapes = nest.flatten(
                    sparse.as_dense_shapes(self._state_shapes,
                                           self._state_classes))
                input_dense_shapes = nest.flatten(
                    sparse.as_dense_shapes(input_dataset.output_shapes,
                                           input_dataset.output_classes))
                for flat_arg, dense_shape in zip(
                        args, state_dense_shapes + input_dense_shapes):
                    flat_arg.set_shape(dense_shape)

                # Split the flat arguments into the state and the input element.
                num_state_components = len(nest.flatten(self._state_shapes))
                state = sparse.deserialize_sparse_tensors(
                    nest.pack_sequence_as(self._state_types,
                                          args[:num_state_components]),
                    self._state_types, self._state_shapes, self._state_classes)
                element = sparse.deserialize_sparse_tensors(
                    nest.pack_sequence_as(input_dataset.output_types,
                                          args[num_state_components:]),
                    input_dataset.output_types, input_dataset.output_shapes,
                    input_dataset.output_classes)

                new_state = reduce_func(state, element)

                # Turn `SparseTensorValue`s into `SparseTensor`s and everything
                # else into tensors.
                converted = []
                for component in nest.flatten(new_state):
                    if sparse_tensor.is_sparse(component):
                        converted.append(
                            sparse_tensor.SparseTensor.from_value(component))
                    else:
                        converted.append(ops.convert_to_tensor(component))
                new_state = nest.pack_sequence_as(new_state, converted)

                # Append the shapes of the new state to `flat_new_state_shapes`.
                flat_new_state = nest.flatten(new_state)
                flat_new_state_shapes.extend(
                    [component.get_shape() for component in flat_new_state])

                # The new state must have the same dtypes as the initial state.
                expected_dtypes = nest.flatten(self._state_types)
                for component, expected in zip(flat_new_state, expected_dtypes):
                    if component.dtype != expected:
                        actual_types = nest.pack_sequence_as(
                            self._state_types,
                            [c.dtype for c in flat_new_state])
                        raise TypeError(
                            "The element types for the new state must match the initial "
                            "state. Expected %s; got %s." %
                            (self._state_types, actual_types))

                # Serialize any sparse tensors before flattening the result.
                serialized = sparse.serialize_sparse_tensors(new_state)
                new_state = nest.pack_sequence_as(
                    new_state, list(nest.flatten(serialized)))
                return nest.flatten(new_state)
示例#2
0
      def tf_reduce_func(*args):
        """Wraps `reduce_func` for Defun and restores static shapes.

        The flat `args` hold the state components followed by the input element
        components. Both are re-nested and passed to `reduce_func`; the new
        state is type-checked against the initial state and returned as a flat
        list with sparse tensors serialized.

        Raises:
          TypeError: if a component dtype of the new state differs from the
            corresponding dtype in `self._state_types`.
        """
        # Re-attach static shape information to every flat argument.
        state_dense_shapes = nest.flatten(
            sparse.as_dense_shapes(self._state_shapes, self._state_classes))
        input_dense_shapes = nest.flatten(
            sparse.as_dense_shapes(input_dataset.output_shapes,
                                   input_dataset.output_classes))
        for flat_arg, dense_shape in zip(
            args, state_dense_shapes + input_dense_shapes):
          flat_arg.set_shape(dense_shape)

        # Split the flat arguments into the state and the input element.
        num_state_components = len(nest.flatten(self._state_shapes))
        state = sparse.deserialize_sparse_tensors(
            nest.pack_sequence_as(self._state_types,
                                  args[:num_state_components]),
            self._state_types, self._state_shapes, self._state_classes)
        element = sparse.deserialize_sparse_tensors(
            nest.pack_sequence_as(input_dataset.output_types,
                                  args[num_state_components:]),
            input_dataset.output_types, input_dataset.output_shapes,
            input_dataset.output_classes)

        new_state = reduce_func(state, element)

        # Turn `SparseTensorValue`s into `SparseTensor`s and everything else
        # into tensors.
        converted = []
        for component in nest.flatten(new_state):
          if sparse_tensor.is_sparse(component):
            converted.append(sparse_tensor.SparseTensor.from_value(component))
          else:
            converted.append(ops.convert_to_tensor(component))
        new_state = nest.pack_sequence_as(new_state, converted)

        # Append the shapes of the new state to `flat_new_state_shapes`.
        flat_new_state = nest.flatten(new_state)
        flat_new_state_shapes.extend(
            [component.get_shape() for component in flat_new_state])

        # The new state must have the same dtypes as the initial state.
        expected_dtypes = nest.flatten(self._state_types)
        for component, expected in zip(flat_new_state, expected_dtypes):
          if component.dtype != expected:
            actual_types = nest.pack_sequence_as(
                self._state_types, [c.dtype for c in flat_new_state])
            raise TypeError(
                "The element types for the new state must match the initial "
                "state. Expected %s; got %s." %
                (self._state_types, actual_types))

        dataset_ops._warn_if_collections("tf.contrib.data.group_by_reducer()")  # pylint: disable=protected-access

        # Serialize any sparse tensors before flattening the result.
        serialized = sparse.serialize_sparse_tensors(new_state)
        new_state = nest.pack_sequence_as(
            new_state, list(nest.flatten(serialized)))
        return nest.flatten(new_state)
示例#3
0
    def _next_internal(self):
        """Fetches the next element as a nested structure of `tf.Tensor`s.

        When executing eagerly, the synchronous get-next op is used inside a
        SYNC execution-mode context; otherwise the regular get-next op is used
        and the result is rebuilt through `self._structure`.
        """
        if context.executing_eagerly():
            # This runs in sync mode as iterators use an error status to
            # communicate that there is no more data to iterate over.
            # TODO(b/77291417): Fix
            with context.execution_mode(context.SYNC):
                with ops.device(self._device):
                    # TODO(ashankar): Consider removing this ops.device()
                    # contextmanager and instead mimic ops placement in graphs:
                    # Operations on resource handles execute on the same device
                    # as where the resource is placed.
                    # NOTE(mrry): Here we use the "_sync" variant of
                    # `iterator_get_next` because in eager mode this code will
                    # run synchronously on the calling thread. Therefore we do
                    # not need to make a defensive context switch to a
                    # background thread, and can achieve a small constant
                    # performance boost by invoking the iterator synchronously.
                    flat_components = gen_dataset_ops.iterator_get_next_sync(
                        self._iterator_resource,
                        output_types=self._flat_output_types,
                        output_shapes=self._flat_output_shapes)

                nested_components = nest.pack_sequence_as(
                    self._output_types, flat_components)
                return sparse.deserialize_sparse_tensors(
                    nested_components, self._output_types,
                    self._output_shapes, self._output_classes)

        with ops.device(self._device):
            flat_components = gen_dataset_ops.iterator_get_next(
                self._iterator_resource,
                output_types=self._flat_output_types,
                output_shapes=self._flat_output_shapes)
        return self._structure._from_compatible_tensor_list(flat_components)  # pylint: disable=protected-access
示例#4
0
        def tf_finalize_func(*args):
            """Wraps `finalize_func` for Defun and restores static shapes.

            Re-nests the flat state `args`, calls `finalize_func`, records the
            classes, shapes and dtypes of its result on `self`, and returns the
            result as a flat list with sparse tensors serialized.
            """
            # Re-attach static shape information to every flat argument.
            state_dense_shapes = nest.flatten(
                sparse.as_dense_shapes(self._state_shapes,
                                       self._state_classes))
            for flat_arg, dense_shape in zip(args, state_dense_shapes):
                flat_arg.set_shape(dense_shape)

            state = sparse.deserialize_sparse_tensors(
                nest.pack_sequence_as(self._state_types, args),
                self._state_types, self._state_shapes, self._state_classes)

            result = finalize_func(state)

            # Turn `SparseTensorValue`s into `SparseTensor`s and everything
            # else into tensors.
            normalized = []
            for component in nest.flatten(result):
                if sparse_tensor.is_sparse(component):
                    normalized.append(
                        sparse_tensor.SparseTensor.from_value(component))
                else:
                    normalized.append(ops.convert_to_tensor(component))
            result = nest.pack_sequence_as(result, normalized)

            # Record the output structure observed from the finalized value.
            self._output_classes = sparse.get_classes(result)
            flat_result = nest.flatten(result)
            self._output_shapes = nest.pack_sequence_as(
                result, [component.get_shape() for component in flat_result])
            self._output_types = nest.pack_sequence_as(
                result, [component.dtype for component in flat_result])

            # Serialize any sparse tensors before flattening the result.
            serialized = sparse.serialize_sparse_tensors(result)
            result = nest.pack_sequence_as(
                result, list(nest.flatten(serialized)))
            return nest.flatten(result)
示例#5
0
  def get_next(self, name=None):
    """Returns the next element as a nested structure of `tf.Tensor`s.

    Each call is counted; once the count exceeds
    `GET_NEXT_CALL_WARNING_THRESHOLD`, a warning is issued on every call.

    Args:
      name: (Optional.) A name for the created operation.

    Returns:
      A nested structure of `tf.Tensor` objects.
    """
    self._get_next_call_count += 1
    if self._get_next_call_count > GET_NEXT_CALL_WARNING_THRESHOLD:
      warnings.warn(GET_NEXT_CALL_WARNING_MESSAGE)

    # The op works on flat lists of dense types/shapes.
    flat_dense_types = nest.flatten(
        sparse.as_dense_types(self._output_types, self._output_classes))
    flat_dense_shapes = nest.flatten(
        sparse.as_dense_shapes(self._output_shapes, self._output_classes))
    flat_components = gen_dataset_ops.iterator_get_next(
        self._iterator_resource,
        output_types=flat_dense_types,
        output_shapes=flat_dense_shapes,
        name=name)

    # Rebuild the nested structure and restore any sparse tensors.
    nested_components = nest.pack_sequence_as(self._output_types,
                                              flat_components)
    return sparse.deserialize_sparse_tensors(
        nested_components, self._output_types, self._output_shapes,
        self._output_classes)
示例#6
0
 def testSerializeDeserialize(self):
   """Checks that serialize -> deserialize round-trips sparse structures.

   Each case is serialized with `sparse.serialize_sparse_tensors`, restored
   with `sparse.deserialize_sparse_tensors`, and compared with the input for
   both nested structure and sparse values.
   """
   test_cases = (
       (),
       sparse_tensor.SparseTensor(
           indices=[[0, 0]], values=[1], dense_shape=[1, 1]),
       sparse_tensor.SparseTensor(
           indices=[[3, 4]], values=[-1], dense_shape=[4, 5]),
       sparse_tensor.SparseTensor(
           indices=[[0, 0], [3, 4]], values=[1, -1], dense_shape=[4, 5]),
       # The trailing comma makes this a real one-element tuple. Without it
       # the parentheses are only grouping and the case would duplicate the
       # bare `SparseTensor` case at the top of this list.
       (sparse_tensor.SparseTensor(
           indices=[[0, 0]], values=[1], dense_shape=[1, 1]),),
       (sparse_tensor.SparseTensor(
           indices=[[0, 0]], values=[1], dense_shape=[1, 1]), ()),
       ((), sparse_tensor.SparseTensor(
           indices=[[0, 0]], values=[1], dense_shape=[1, 1])),
   )
   for expected in test_cases:
     classes = sparse.get_classes(expected)
     # Shapes are left fully unknown; every component is declared as int32.
     shapes = nest.map_structure(lambda _: tensor_shape.TensorShape(None),
                                 classes)
     types = nest.map_structure(lambda _: dtypes.int32, classes)
     actual = sparse.deserialize_sparse_tensors(
         sparse.serialize_sparse_tensors(expected), types, shapes, classes)
     nest.assert_same_structure(expected, actual)
     for a, e in zip(nest.flatten(actual), nest.flatten(expected)):
       self.assertSparseValuesEqual(a, e)
示例#7
0
  def _next_internal(self):
    """Fetches the next element as a nested structure of `tf.Tensor`s.

    Reads from the function buffering resource when a handle is set, and from
    the iterator resource via the synchronous get-next op otherwise.
    """
    with ops.device(self._device):
      if self._buffer_resource_handle is None:
        # TODO(ashankar): Consider removing this ops.device() contextmanager
        # and instead mimic ops placement in graphs: Operations on resource
        # handles execute on the same device as where the resource is placed.
        # NOTE(mrry): Here we use the "_sync" variant of `iterator_get_next`
        # because in eager mode this code will run synchronously on the calling
        # thread. Therefore we do not need to make a defensive context switch
        # to a background thread, and can achieve a small constant performance
        # boost by invoking the iterator synchronously.
        flat_components = gen_dataset_ops.iterator_get_next_sync(
            self._resource,
            output_types=self._flat_output_types,
            output_shapes=self._flat_output_shapes)
      else:
        flat_components = prefetching_ops.function_buffering_resource_get_next(
            function_buffer_resource=self._buffer_resource_handle,
            output_types=self._flat_output_types)

    # Rebuild the nested structure and restore any sparse tensors.
    nested_components = nest.pack_sequence_as(self._output_types,
                                              flat_components)
    return sparse.deserialize_sparse_tensors(
        nested_components, self._output_types, self._output_shapes,
        self._output_classes)
示例#8
0
  def _next_internal(self):
    """Fetches the next element as a nested structure of `tf.Tensor`s.

    Reads from the function buffering resource when a handle is set, and from
    the iterator resource via the synchronous get-next op otherwise.
    """
    use_buffer = self._buffer_resource_handle is not None
    with ops.device(self._device):
      if use_buffer:
        flat_values = prefetching_ops.function_buffering_resource_get_next(
            function_buffer_resource=self._buffer_resource_handle,
            output_types=self._flat_output_types)
      else:
        # TODO(ashankar): Consider removing this ops.device() contextmanager
        # and instead mimic ops placement in graphs: Operations on resource
        # handles execute on the same device as where the resource is placed.
        # NOTE(mrry): Here we use the "_sync" variant of `iterator_get_next`
        # because in eager mode this code will run synchronously on the calling
        # thread. Therefore we do not need to make a defensive context switch
        # to a background thread, and can achieve a small constant performance
        # boost by invoking the iterator synchronously.
        flat_values = gen_dataset_ops.iterator_get_next_sync(
            self._resource,
            output_types=self._flat_output_types,
            output_shapes=self._flat_output_shapes)

    # Rebuild the nested structure and restore any sparse tensors.
    nested_values = nest.pack_sequence_as(self._output_types, flat_values)
    return sparse.deserialize_sparse_tensors(
        nested_values, self._output_types, self._output_shapes,
        self._output_classes)
示例#9
0
 def testSerializeManyDeserialize(self):
   """Checks that serialize-many -> deserialize round-trips sparse structures.

   Each case is serialized with `sparse.serialize_many_sparse_tensors`,
   restored with `sparse.deserialize_sparse_tensors`, and compared with the
   input for both nested structure and sparse values.
   """
   test_cases = (
       (),
       sparse_tensor.SparseTensor(
           indices=[[0, 0]], values=[1], dense_shape=[1, 1]),
       sparse_tensor.SparseTensor(
           indices=[[3, 4]], values=[-1], dense_shape=[4, 5]),
       sparse_tensor.SparseTensor(
           indices=[[0, 0], [3, 4]], values=[1, -1], dense_shape=[4, 5]),
       # The trailing comma makes this a real one-element tuple. Without it
       # the parentheses are only grouping and the case would duplicate the
       # bare `SparseTensor` case at the top of this list.
       (sparse_tensor.SparseTensor(
           indices=[[0, 0]], values=[1], dense_shape=[1, 1]),),
       (sparse_tensor.SparseTensor(
           indices=[[0, 0]], values=[1], dense_shape=[1, 1]), ()),
       ((),
        sparse_tensor.SparseTensor(
            indices=[[0, 0]], values=[1], dense_shape=[1, 1])),
   )
   for expected in test_cases:
     classes = sparse.get_classes(expected)
     # Shapes are left fully unknown; every component is declared as int32.
     shapes = nest.map_structure(lambda _: tensor_shape.TensorShape(None),
                                 classes)
     types = nest.map_structure(lambda _: dtypes.int32, classes)
     actual = sparse.deserialize_sparse_tensors(
         sparse.serialize_many_sparse_tensors(expected), types, shapes,
         classes)
     nest.assert_same_structure(expected, actual)
     for a, e in zip(nest.flatten(actual), nest.flatten(expected)):
       self.assertSparseValuesEqual(a, e)
    def get_next(self, name=None):
        """See `tf.data.Iterator.get_next`."""
        self._get_next_call_count += 1
        if self._get_next_call_count > iterator_ops.GET_NEXT_CALL_WARNING_THRESHOLD:
            warnings.warn(iterator_ops.GET_NEXT_CALL_WARNING_MESSAGE)

        # One element is pulled from each buffering resource; the results are
        # packed using the structure of `self._devices` at the end.
        per_resource_elements = []
        # TODO(priyag): This will fail if the input size (typically number of
        # batches) is not divisible by number of devices.
        # How do we handle that more gracefully / let the user know?
        for buffer_resource in self._buffering_resources:
            flat_dense_types = data_nest.flatten(
                sparse.as_dense_types(self.output_types, self.output_classes))
            flat_components = (
                ged_ops.experimental_function_buffering_resource_get_next(
                    buffer_resource,
                    output_types=flat_dense_types,
                    name=name))

            nested_components = data_nest.pack_sequence_as(
                self.output_types, flat_components)
            element = sparse.deserialize_sparse_tensors(
                nested_components, self.output_types, self.output_shapes,
                self.output_classes)

            # Set the known static shape on every dense tensor component.
            flat_shapes = data_nest.flatten(self.output_shapes)
            for component, shape in zip(data_nest.flatten(element),
                                        flat_shapes):
                if isinstance(component, ops.Tensor):
                    component.set_shape(shape)
            per_resource_elements.append(element)

        return nest.pack_sequence_as(self._devices, per_resource_elements)
示例#11
0
        def tf_key_func(*args):
            """Wraps `key_func` for Defun and restores static shapes.

            Returns:
              The result of `key_func` converted to a tensor.

            Raises:
              ValueError: if the key is not a scalar `tf.int64` tensor.
            """
            # Propagate the input dataset's static shapes to the flat args.
            flat_dense_shapes = nest.flatten(
                sparse.as_dense_shapes(input_dataset.output_shapes,
                                       input_dataset.output_classes))
            for flat_arg, dense_shape in zip(args, flat_dense_shapes):
                flat_arg.set_shape(dense_shape)

            element = sparse.deserialize_sparse_tensors(
                nest.pack_sequence_as(input_dataset.output_types, args),
                input_dataset.output_types, input_dataset.output_shapes,
                input_dataset.output_classes)

            # The element is either splatted into `key_func` or passed whole.
            unpack = dataset_ops._should_unpack_args(element)  # pylint: disable=protected-access
            key = ops.convert_to_tensor(
                key_func(*element) if unpack else key_func(element))

            has_wrong_dtype = key.dtype != dtypes.int64
            if has_wrong_dtype or key.get_shape() != tensor_shape.scalar():
                raise ValueError(
                    "`key_func` must return a single tf.int64 tensor. "
                    "Got type=%s and shape=%s" % (key.dtype, key.get_shape()))
            return key
示例#12
0
  def get_next(self, name=None):
    """See `tf.data.Iterator.get_next`."""
    self._get_next_call_count += 1
    if self._get_next_call_count > iterator_ops.GET_NEXT_CALL_WARNING_THRESHOLD:
      warnings.warn(iterator_ops.GET_NEXT_CALL_WARNING_MESSAGE)

    # One element is pulled from each buffering resource; the results are
    # packed using the structure of `self._devices` at the end.
    per_resource_elements = []
    # TODO(priyag): This will fail if the input size (typically number of
    # batches) is not divisible by number of devices.
    # How do we handle that more gracefully / let the user know?
    for buffer_resource in self._buffering_resources:
      flat_dense_types = data_nest.flatten(
          sparse.as_dense_types(self.output_types, self.output_classes))
      flat_components = gen_dataset_ops.function_buffering_resource_get_next(
          buffer_resource, output_types=flat_dense_types, name=name)

      nested_components = data_nest.pack_sequence_as(self.output_types,
                                                     flat_components)
      element = sparse.deserialize_sparse_tensors(
          nested_components, self.output_types, self.output_shapes,
          self.output_classes)

      # Set the known static shape on every dense tensor component.
      flat_shapes = data_nest.flatten(self.output_shapes)
      for component, shape in zip(data_nest.flatten(element), flat_shapes):
        if isinstance(component, ops.Tensor):
          component.set_shape(shape)
      per_resource_elements.append(element)

    return nest.pack_sequence_as(self._devices, per_resource_elements)
示例#13
0
    def tf_finalize_func(*args):
      """Wraps `finalize_func` for Defun and restores static shapes.

      Re-nests the flat state `args`, calls `finalize_func`, records the
      classes, shapes and dtypes of its result on `self`, and returns the
      result as a flat list with sparse tensors serialized.
      """
      # Re-attach static shape information to every flat argument.
      state_dense_shapes = nest.flatten(
          sparse.as_dense_shapes(self._state_shapes, self._state_classes))
      for flat_arg, dense_shape in zip(args, state_dense_shapes):
        flat_arg.set_shape(dense_shape)

      state = sparse.deserialize_sparse_tensors(
          nest.pack_sequence_as(self._state_types, args),
          self._state_types, self._state_shapes, self._state_classes)

      result = finalize_func(state)

      # Turn `SparseTensorValue`s into `SparseTensor`s and everything else
      # into tensors.
      normalized = []
      for component in nest.flatten(result):
        if sparse_tensor.is_sparse(component):
          normalized.append(sparse_tensor.SparseTensor.from_value(component))
        else:
          normalized.append(ops.convert_to_tensor(component))
      result = nest.pack_sequence_as(result, normalized)

      # Record the output structure observed from the finalized value.
      self._output_classes = sparse.get_classes(result)
      flat_result = nest.flatten(result)
      self._output_shapes = nest.pack_sequence_as(
          result, [component.get_shape() for component in flat_result])
      self._output_types = nest.pack_sequence_as(
          result, [component.dtype for component in flat_result])

      # Serialize any sparse tensors before flattening the result.
      serialized = sparse.serialize_sparse_tensors(result)
      result = nest.pack_sequence_as(result, list(nest.flatten(serialized)))
      return nest.flatten(result)
示例#14
0
    def tf_key_func(*args):
      """Wraps `key_func` for Defun and restores static shapes.

      Returns:
        The result of `key_func` converted to a tensor.

      Raises:
        ValueError: if the key is not a scalar `tf.int64` tensor.
      """
      # Propagate the input dataset's static shapes to the flat arguments.
      flat_dense_shapes = nest.flatten(
          sparse.as_dense_shapes(input_dataset.output_shapes,
                                 input_dataset.output_classes))
      for flat_arg, dense_shape in zip(args, flat_dense_shapes):
        flat_arg.set_shape(dense_shape)

      element = sparse.deserialize_sparse_tensors(
          nest.pack_sequence_as(input_dataset.output_types, args),
          input_dataset.output_types, input_dataset.output_shapes,
          input_dataset.output_classes)

      # The element is either splatted into `key_func` or passed whole.
      unpack = dataset_ops._should_unpack_args(element)  # pylint: disable=protected-access
      key = ops.convert_to_tensor(
          key_func(*element) if unpack else key_func(element))

      has_wrong_dtype = key.dtype != dtypes.int64
      if has_wrong_dtype or key.get_shape() != tensor_shape.scalar():
        raise ValueError(
            "`key_func` must return a single tf.int64 tensor. "
            "Got type=%s and shape=%s" % (key.dtype, key.get_shape()))
      dataset_ops._warn_if_collections("tf.contrib.data.group_by_reducer()")  # pylint: disable=protected-access
      return key
示例#15
0
    def tf_map_func(*args):
      """Wraps `map_func` for Defun and restores static shapes.

      Records the output classes, types and shapes of the dataset returned by
      `map_func` on `self`, and returns that dataset's variant tensor.

      Raises:
        TypeError: if `map_func` does not return a `Dataset`.
      """
      # Propagate the input dataset's static shapes to the flat arguments.
      flat_dense_shapes = nest.flatten(
          sparse.as_dense_shapes(input_dataset.output_shapes,
                                 input_dataset.output_classes))
      for flat_arg, dense_shape in zip(args, flat_dense_shapes):
        flat_arg.set_shape(dense_shape)

      element = sparse.deserialize_sparse_tensors(
          nest.pack_sequence_as(input_dataset.output_types, args),
          input_dataset.output_types, input_dataset.output_shapes,
          input_dataset.output_classes)

      # The element is either splatted into `map_func` or passed whole.
      unpack = dataset_ops._should_unpack_args(element)  # pylint: disable=protected-access
      dataset = map_func(*element) if unpack else map_func(element)

      if not isinstance(dataset, dataset_ops.Dataset):
        raise TypeError("`map_func` must return a `Dataset` object.")

      # Remember the structure of the returned dataset.
      self._output_classes = dataset.output_classes
      self._output_types = dataset.output_types
      self._output_shapes = dataset.output_shapes

      variant_tensor = dataset._as_variant_tensor()  # pylint: disable=protected-access
      return variant_tensor
示例#16
0
 def testSerializeDeserialize(self):
     """Checks that serialize -> deserialize round-trips sparse structures.

     Each case is serialized with `sparse.serialize_sparse_tensors`, restored
     with `sparse.deserialize_sparse_tensors` using
     `sparse.get_sparse_types`, and compared with the input for both nested
     structure and sparse values.
     """
     test_cases = (
         (),
         sparse_tensor.SparseTensor(indices=[[0, 0]],
                                    values=[1],
                                    dense_shape=[1, 1]),
         sparse_tensor.SparseTensor(indices=[[3, 4]],
                                    values=[-1],
                                    dense_shape=[4, 5]),
         sparse_tensor.SparseTensor(indices=[[0, 0], [3, 4]],
                                    values=[1, -1],
                                    dense_shape=[4, 5]),
         # The trailing comma makes this a real one-element tuple. Without
         # it the parentheses are only grouping and the case would duplicate
         # the bare `SparseTensor` case at the top of this list.
         (sparse_tensor.SparseTensor(indices=[[0, 0]],
                                     values=[1],
                                     dense_shape=[1, 1]),),
         (sparse_tensor.SparseTensor(indices=[[0, 0]],
                                     values=[1],
                                     dense_shape=[1, 1]), ()),
         ((),
          sparse_tensor.SparseTensor(indices=[[0, 0]],
                                     values=[1],
                                     dense_shape=[1, 1])),
     )
     for expected in test_cases:
         actual = sparse.deserialize_sparse_tensors(
             sparse.serialize_sparse_tensors(expected),
             sparse.get_sparse_types(expected))
         nest.assert_same_structure(expected, actual)
         for a, e in zip(nest.flatten(actual), nest.flatten(expected)):
             self.assertSparseValuesEqual(a, e)
示例#17
0
        def tf_key_func(*args):
            """Wraps `key_func` for Defun and restores static shapes.

            Returns:
              The result of `key_func` converted to a `tf.int64` tensor.

            Raises:
              ValueError: if the converted key does not have dtype `tf.int64`.
            """
            # Propagate the input dataset's static shapes to the flat args.
            flat_dense_shapes = nest.flatten(
                sparse.as_dense_shapes(input_dataset.output_shapes,
                                       input_dataset.output_classes))
            for flat_arg, dense_shape in zip(args, flat_dense_shapes):
                flat_arg.set_shape(dense_shape)

            element = sparse.deserialize_sparse_tensors(
                nest.pack_sequence_as(input_dataset.output_types, args),
                input_dataset.output_types, input_dataset.output_shapes,
                input_dataset.output_classes)

            # The element is either splatted into `key_func` or passed whole.
            unpack = dataset_ops._should_unpack_args(element)  # pylint: disable=protected-access
            raw_key = key_func(*element) if unpack else key_func(element)
            key = ops.convert_to_tensor(raw_key, dtype=dtypes.int64)

            if key.dtype != dtypes.int64:
                raise ValueError(
                    "`key_func` must return a single tf.int64 tensor.")
            # pylint: disable=protected-access
            dataset_ops._warn_if_collections(
                "tf.contrib.data.group_by_window()")
            # pylint: enable=protected-access
            return key
示例#18
0
    def get_next(self, name=None):
        """Returns the next element as a nested structure of `tf.Tensor`s.

        Each call is counted; once the count exceeds
        `GET_NEXT_CALL_WARNING_THRESHOLD`, a warning is issued on every call.

        Args:
          name: (Optional.) A name for the created operation.

        Returns:
          A nested structure of `tf.Tensor` objects.
        """
        self._get_next_call_count += 1
        if self._get_next_call_count > GET_NEXT_CALL_WARNING_THRESHOLD:
            warnings.warn(GET_NEXT_CALL_WARNING_MESSAGE)

        # The op works on flat lists of dense types/shapes.
        flat_dense_types = nest.flatten(
            sparse.as_dense_types(self._output_types, self._output_classes))
        flat_dense_shapes = nest.flatten(
            sparse.as_dense_shapes(self._output_shapes, self._output_classes))
        flat_components = gen_dataset_ops.iterator_get_next(
            self._iterator_resource,
            output_types=flat_dense_types,
            output_shapes=flat_dense_shapes,
            name=name)

        # Rebuild the nested structure and restore any sparse tensors.
        nested_components = nest.pack_sequence_as(self._output_types,
                                                  flat_components)
        return sparse.deserialize_sparse_tensors(
            nested_components, self._output_types, self._output_shapes,
            self._output_classes)
示例#19
0
    def tf_finalize_func(*args):
      """Wraps `finalize_func` for Defun and restores static shapes.

      Re-nests the flat state `args`, calls `finalize_func`, records the
      classes, shapes and dtypes of its result on `self`, and returns the
      result as a flat list with sparse tensors serialized.
      """
      # Re-attach static shape information to every flat argument.
      state_dense_shapes = nest.flatten(
          sparse.as_dense_shapes(self._state_shapes, self._state_classes))
      for flat_arg, dense_shape in zip(args, state_dense_shapes):
        flat_arg.set_shape(dense_shape)

      state = sparse.deserialize_sparse_tensors(
          nest.pack_sequence_as(self._state_types, args),
          self._state_types, self._state_shapes, self._state_classes)

      result = finalize_func(state)

      # Turn `SparseTensorValue`s into `SparseTensor`s and everything else
      # into tensors.
      normalized = []
      for component in nest.flatten(result):
        if sparse_tensor.is_sparse(component):
          normalized.append(sparse_tensor.SparseTensor.from_value(component))
        else:
          normalized.append(ops.convert_to_tensor(component))
      result = nest.pack_sequence_as(result, normalized)

      # Record the output structure observed from the finalized value.
      self._output_classes = sparse.get_classes(result)
      flat_result = nest.flatten(result)
      self._output_shapes = nest.pack_sequence_as(
          result, [component.get_shape() for component in flat_result])
      self._output_types = nest.pack_sequence_as(
          result, [component.dtype for component in flat_result])

      dataset_ops._warn_if_collections("tf.contrib.data.group_by_reducer()")  # pylint: disable=protected-access

      # Serialize any sparse tensors before flattening the result.
      serialized = sparse.serialize_sparse_tensors(result)
      result = nest.pack_sequence_as(result, list(nest.flatten(serialized)))
      return nest.flatten(result)
示例#20
0
 def testSerializeManyDeserialize(self, input_fn):
     """Checks that serialize-many -> deserialize round-trips one test case.

     Args:
       input_fn: A zero-argument callable returning the structure to
         round-trip.
     """
     expected = input_fn()
     classes = sparse.get_classes(expected)

     def _unknown_shape(_):
         return tensor_shape.TensorShape(None)

     def _int32(_):
         return dtypes.int32

     # Shapes are left fully unknown; every component is declared as int32.
     shapes = nest.map_structure(_unknown_shape, classes)
     types = nest.map_structure(_int32, classes)

     serialized = sparse.serialize_many_sparse_tensors(expected)
     actual = sparse.deserialize_sparse_tensors(
         serialized, types, shapes, sparse.get_classes(expected))

     nest.assert_same_structure(expected, actual)
     flat_pairs = zip(nest.flatten(actual), nest.flatten(expected))
     for actual_component, expected_component in flat_pairs:
         self.assertSparseValuesEqual(actual_component, expected_component)
示例#21
0
 def _next_internal(self):
   """Fetches the next element as a nested structure of `tf.Tensor`s.

   Defers to the parent implementation when no buffer resource handle is set;
   otherwise reads from the function buffering resource.
   """
   if self._buffer_resource_handle is None:
     return super(Iterator, self)._next_internal()

   with ops.device(self._device):
     flat_components = prefetching_ops.function_buffering_resource_get_next(
         function_buffer_resource=self._buffer_resource_handle,
         output_types=self._flat_output_types)

   # Rebuild the nested structure and restore any sparse tensors.
   nested_components = nest.pack_sequence_as(self._output_types,
                                             flat_components)
   return sparse.deserialize_sparse_tensors(
       nested_components, self._output_types, self._output_shapes,
       self._output_classes)
示例#22
0
 def _next_internal(self):
     """Fetches the next element as a nested structure of `tf.Tensor`s.

     Defers to the parent implementation when no buffer resource handle is
     set; otherwise reads from the function buffering resource.
     """
     if self._buffer_resource_handle is None:
         return super(Iterator, self)._next_internal()

     with ops.device(self._device):
         flat_values = prefetching_ops.function_buffering_resource_get_next(
             function_buffer_resource=self._buffer_resource_handle,
             output_types=self._flat_output_types)

     # Rebuild the nested structure and restore any sparse tensors.
     nested_values = nest.pack_sequence_as(self._output_types, flat_values)
     return sparse.deserialize_sparse_tensors(
         nested_values, self._output_types, self._output_shapes,
         self._output_classes)
示例#23
0
 def _next_internal(self):
   """Fetches the next element as a nested structure of `tf.Tensor`s.

   The element is read from `self._buffering_resource` on `self._device`.
   """
   # This runs in sync mode as iterators use an error status to communicate
   # that there is no more data to iterate over.
   # TODO(b/77291417): Fix
   with context.execution_mode(context.SYNC):
     with ops.device(self._device):
       flat_components = (
           ged_ops.experimental_function_buffering_resource_get_next(
               function_buffer_resource=self._buffering_resource,
               output_types=self._flat_output_types))

     # Rebuild the nested structure and restore any sparse tensors.
     nested_components = nest.pack_sequence_as(self._output_types,
                                               flat_components)
     return sparse.deserialize_sparse_tensors(
         nested_components, self._output_types, self._output_shapes,
         self._output_classes)
示例#24
0
 def _next_internal(self):
   """Fetches the next element as a nested structure of `tf.Tensor`s.

   The element is read from `self._buffering_resource` on `self._device`.
   """
   # This runs in sync mode as iterators use an error status to communicate
   # that there is no more data to iterate over.
   # TODO(b/77291417): Fix
   with context.execution_mode(context.SYNC):
     with ops.device(self._device):
       flat_components = gen_dataset_ops.function_buffering_resource_get_next(
           function_buffer_resource=self._buffering_resource,
           output_types=self._flat_output_types)

     # Rebuild the nested structure and restore any sparse tensors.
     nested_components = nest.pack_sequence_as(self._output_types,
                                               flat_components)
     return sparse.deserialize_sparse_tensors(
         nested_components, self._output_types, self._output_shapes,
         self._output_classes)
示例#25
0
def get_single_element(dataset):
  """Returns the single element in `dataset` as a nested structure of tensors.

  With this function a @{tf.data.Dataset} can be used in a stateless
  "tensor-in tensor-out" expression, with no @{tf.data.Iterator} involved.
  That is handy when preprocessing transformations are written as a `Dataset`
  and the same transformation is wanted at serving time. For example:

  ```python
  input_batch = tf.placeholder(tf.string, shape=[BATCH_SIZE])

  def preprocessing_fn(input_str):
    # ...
    return image, label

  dataset = (tf.data.Dataset.from_tensor_slices(input_batch)
             .map(preprocessing_fn, num_parallel_calls=BATCH_SIZE)
             .batch(BATCH_SIZE))

  image_batch, label_batch = tf.contrib.data.get_single_element(dataset)
  ```

  Args:
    dataset: A @{tf.data.Dataset} object containing a single element.

  Returns:
    A nested structure of @{tf.Tensor} objects, corresponding to the single
    element of `dataset`.

  Raises:
    TypeError: if `dataset` is not a `tf.data.Dataset` object.
    InvalidArgumentError (at runtime): if `dataset` does not contain exactly
      one element.
  """
  if not isinstance(dataset, dataset_ops.Dataset):
    raise TypeError("`dataset` must be a `tf.data.Dataset` object.")

  variant = dataset._as_variant_tensor()  # pylint: disable=protected-access
  # Flat dense views of the element types and shapes, as the op expects.
  flat_types = nest.flatten(
      sparse.as_dense_types(dataset.output_types, dataset.output_classes))
  flat_shapes = nest.flatten(
      sparse.as_dense_shapes(dataset.output_shapes, dataset.output_classes))
  flat_element = gen_dataset_ops.dataset_to_single_element(
      variant, output_types=flat_types, output_shapes=flat_shapes)

  # Restore the nested structure, then rebuild any sparse components.
  nested_element = nest.pack_sequence_as(dataset.output_types, flat_element)
  return sparse.deserialize_sparse_tensors(
      nested_element, dataset.output_types, dataset.output_shapes,
      dataset.output_classes)
示例#26
0
 def get_value(self, name=None):
   """Builds the `optional_get_value` op and restructures its outputs.

   Args:
     name: (Optional.) Name handed to `ops.name_scope`; the resulting scope
       is used as the name of the created op.

   Returns:
     The op's flat outputs packed into the structure of `self._output_types`
     and passed through `sparse.deserialize_sparse_tensors`.
   """
   # TODO(b/110122868): Consolidate the restructuring logic with similar logic
   # in `Iterator.get_next()` and `StructuredFunctionWrapper`.
   with ops.name_scope(name, "OptionalGetValue",
                       [self._variant_tensor]) as scope:
     flat_types = nest.flatten(
         sparse.as_dense_types(self._output_types, self._output_classes))
     flat_shapes = nest.flatten(
         sparse.as_dense_shapes(self._output_shapes, self._output_classes))
     flat_value = gen_dataset_ops.optional_get_value(
         self._variant_tensor,
         name=scope,
         output_types=flat_types,
         output_shapes=flat_shapes)
     nested_value = nest.pack_sequence_as(self._output_types, flat_value)
     return sparse.deserialize_sparse_tensors(
         nested_value, self._output_types, self._output_shapes,
         self._output_classes)
示例#27
0
    def get_next(self, name=None):
        """Creates the op that produces the next element of this iterator.

        Args:
          name: (Optional.) A name for the created operation.

        Returns:
          A nested structure of `tf.Tensor` objects.
        """
        flat_types = nest.flatten(
            sparse.unwrap_sparse_types(self._output_types))
        flat_shapes = nest.flatten(self._output_shapes)
        flat_components = gen_dataset_ops.iterator_get_next(
            self._iterator_resource,
            output_types=flat_types,
            output_shapes=flat_shapes,
            name=name)
        # Re-nest the flat op outputs before rebuilding sparse components.
        nested_components = nest.pack_sequence_as(self._output_types,
                                                  flat_components)
        return sparse.deserialize_sparse_tensors(nested_components,
                                                 self._output_types)
示例#28
0
 def get_value(self, name=None):
   """Creates the `optional_get_value` op and re-nests what it returns.

   Args:
     name: (Optional.) Name given to `ops.name_scope`; the scope it yields
       names the created op.

   Returns:
     The flat op outputs, packed according to `self._output_types` and then
     run through `sparse.deserialize_sparse_tensors`.
   """
   # TODO(b/110122868): Consolidate the restructuring logic with similar logic
   # in `Iterator.get_next()` and `StructuredFunctionWrapper`.
   scope_values = [self._variant_tensor]
   with ops.name_scope(name, "OptionalGetValue", scope_values) as scope:
     dense_types = sparse.as_dense_types(self._output_types,
                                         self._output_classes)
     dense_shapes = sparse.as_dense_shapes(self._output_shapes,
                                           self._output_classes)
     flat_outputs = gen_dataset_ops.optional_get_value(
         self._variant_tensor,
         name=scope,
         output_types=nest.flatten(dense_types),
         output_shapes=nest.flatten(dense_shapes))
     return sparse.deserialize_sparse_tensors(
         nest.pack_sequence_as(self._output_types, flat_outputs),
         self._output_types, self._output_shapes, self._output_classes)
示例#29
0
    def tf_key_func(*args):
      """Defun wrapper around `key_func` that facilitates shape inference.

      Args:
        *args: Flat components of one input element.

      Returns:
        The result of `key_func` converted to a `tf.int64` tensor.

      Raises:
        ValueError: If the converted result does not have dtype `tf.int64`.
      """
      # Propagate the static shapes known from the input dataset.
      input_shapes = nest.flatten(input_dataset.output_shapes)
      for component, static_shape in zip(args, input_shapes):
        component.set_shape(static_shape)

      element = sparse.deserialize_sparse_tensors(
          nest.pack_sequence_as(input_dataset.output_types, args),
          input_dataset.output_types)

      # pylint: disable=protected-access
      unpack = dataset_ops._should_unpack_args(element)
      # pylint: enable=protected-access
      key = key_func(*element) if unpack else key_func(element)

      key = ops.convert_to_tensor(key, dtype=dtypes.int64)
      if key.dtype != dtypes.int64:
        raise ValueError("`key_func` must return a single tf.int64 tensor.")
      return key
示例#30
0
  def get_next(self, name=None):
    """Creates the op that yields this iterator's next element.

    Args:
      name: (Optional.) A name for the created operation.

    Returns:
      A nested structure of `tf.Tensor` objects.
    """
    dense_types = sparse.unwrap_sparse_types(self._output_types)
    flat_components = gen_dataset_ops.iterator_get_next(
        self._iterator_resource,
        output_types=nest.flatten(dense_types),
        output_shapes=nest.flatten(self._output_shapes),
        name=name)
    # Restore nesting first; sparse components are rebuilt afterwards.
    nested = nest.pack_sequence_as(self._output_types, flat_components)
    return sparse.deserialize_sparse_tensors(nested, self._output_types)
示例#31
0
 def _next_internal(self):
     """Returns the next element, reading from the buffer resource if present.

     When `self._buffer_resource_handle` is None the call is delegated to the
     parent class; otherwise the element is fetched from the buffering
     resource and restructured according to `self._output_types`.
     """
     # Sync mode is used because iterators rely on an error status to signal
     # that there is no more data to iterate over.
     # TODO (b/77291417): Fix id:669
     # https://github.com/imdone/tensorflow/issues/670
     with context.execution_mode(context.SYNC):
         if self._buffer_resource_handle is None:
             return super(Iterator, self)._next_internal()

         with ops.device(self._device):
             flat_components = (
                 prefetching_ops.function_buffering_resource_get_next(
                     function_buffer_resource=self._buffer_resource_handle,
                     output_types=self._flat_output_types))
         nested = nest.pack_sequence_as(self._output_types, flat_components)
         return sparse.deserialize_sparse_tensors(
             nested, self._output_types, self._output_shapes,
             self._output_classes)
示例#32
0
  def _next_internal(self):
    """Returns the next element as a nested structure.

    The flat components come from the buffering resource when
    `self._buffer_resource_handle` is set, and from `iterator_get_next` on
    `self._resource` otherwise.
    """
    with ops.device(self._device):
      if self._buffer_resource_handle is None:
        # TODO(ashankar): Consider removing this ops.device() contextmanager
        # and instead mimic ops placement in graphs: Operations on resource
        # handles execute on the same device as where the resource is placed.
        flat_components = gen_dataset_ops.iterator_get_next(
            self._resource,
            output_types=self._flat_output_types,
            output_shapes=self._flat_output_shapes)
      else:
        flat_components = prefetching_ops.function_buffering_resource_get_next(
            function_buffer_resource=self._buffer_resource_handle,
            output_types=self._flat_output_types)

    nested = nest.pack_sequence_as(self._output_types, flat_components)
    return sparse.deserialize_sparse_tensors(
        nested, self._output_types, self._output_shapes, self._output_classes)
示例#33
0
    def _next_internal(self):
        """Produces the next element as a nested structure.

        Reads from the buffering resource if `self._buffer_resource_handle`
        is set; otherwise calls `iterator_get_next` on `self._resource`. The
        flat result is then packed to match `self._output_types` and passed
        through `sparse.deserialize_sparse_tensors`.
        """
        use_buffer = self._buffer_resource_handle is not None
        with ops.device(self._device):
            if use_buffer:
                flat_values = (
                    prefetching_ops.function_buffering_resource_get_next(
                        function_buffer_resource=self._buffer_resource_handle,
                        output_types=self._flat_output_types))
            else:
                # TODO(ashankar): Consider removing this ops.device()
                # contextmanager and instead mimic ops placement in graphs:
                # Operations on resource handles execute on the same device
                # as where the resource is placed.
                flat_values = gen_dataset_ops.iterator_get_next(
                    self._resource,
                    output_types=self._flat_output_types,
                    output_shapes=self._flat_output_shapes)

        structured = nest.pack_sequence_as(self._output_types, flat_values)
        return sparse.deserialize_sparse_tensors(
            structured, self._output_types, self._output_shapes,
            self._output_classes)
示例#34
0
  def get_next(self, name=None):
    """See @{tf.data.Iterator.get_next}.

    Counts the call, warns once the count exceeds
    `iterator_ops.GET_NEXT_CALL_WARNING_THRESHOLD`, then fetches the next
    element from the buffering resource and applies the static output shapes
    to every `ops.Tensor` component.

    Args:
      name: (Optional.) A name for the created operation.

    Returns:
      The restructured next element.
    """
    self._get_next_call_count += 1
    if self._get_next_call_count > iterator_ops.GET_NEXT_CALL_WARNING_THRESHOLD:
      warnings.warn(iterator_ops.GET_NEXT_CALL_WARNING_MESSAGE)

    dense_types = sparse.as_dense_types(self.output_types, self.output_classes)
    flat_components = gen_dataset_ops.function_buffering_resource_get_next(
        self._buffering_resource,
        output_types=nest.flatten(dense_types),
        name=name)

    element = sparse.deserialize_sparse_tensors(
        nest.pack_sequence_as(self.output_types, flat_components),
        self.output_types, self.output_shapes, self.output_classes)

    # Only dense tensors receive the static shape; other components are
    # skipped by the isinstance check.
    static_shapes = nest.flatten(self.output_shapes)
    for component, static_shape in zip(nest.flatten(element), static_shapes):
      if not isinstance(component, ops.Tensor):
        continue
      component.set_shape(static_shape)

    return element
示例#35
0
        def tf_finalize_func(*args):
            """Defun wrapper around `finalize_func` for shape inference.

            Applies the static state shapes to the flat arguments, rebuilds
            the nested state, calls `finalize_func`, and records the classes,
            shapes and types of the result on `self`.

            Args:
              *args: Flat components of the state.

            Returns:
              The flattened result with any sparse tensors serialized.
            """
            state_shapes = nest.flatten(
                sparse.as_dense_shapes(self._state_shapes,
                                       self._state_classes))
            for component, static_shape in zip(args, state_shapes):
                component.set_shape(static_shape)

            state = sparse.deserialize_sparse_tensors(
                nest.pack_sequence_as(self._state_types, args),
                self._state_types, self._state_shapes, self._state_classes)

            result = finalize_func(state)

            # Turn `SparseTensorValue`s into `SparseTensor`s and everything
            # else into tensors.
            converted = []
            for item in nest.flatten(result):
                if sparse_tensor.is_sparse(item):
                    converted.append(
                        sparse_tensor.SparseTensor.from_value(item))
                else:
                    converted.append(ops.convert_to_tensor(item))
            result = nest.pack_sequence_as(result, converted)

            # Record the output signature derived from the converted result.
            self._output_classes = sparse.get_classes(result)
            self._output_shapes = nest.pack_sequence_as(
                result, [item.get_shape() for item in nest.flatten(result)])
            self._output_types = nest.pack_sequence_as(
                result, [item.dtype for item in nest.flatten(result)])

            # pylint: disable=protected-access
            dataset_ops._warn_if_collections(
                "tf.contrib.data.group_by_reducer()")
            # pylint: enable=protected-access

            # Serialize any sparse tensors.
            serialized = sparse.serialize_sparse_tensors(result)
            result = nest.pack_sequence_as(result,
                                           list(nest.flatten(serialized)))
            return nest.flatten(result)
示例#36
0
 def testSerializeDeserialize(self):
   # Round-trips each case through serialize_sparse_tensors and
   # deserialize_sparse_tensors, then checks structure and values.

   def make_sparse(indices, values, dense_shape):
     return sparse_tensor.SparseTensor(
         indices=indices, values=values, dense_shape=dense_shape)

   def unit_sparse():
     return make_sparse([[0, 0]], [1], [1, 1])

   test_cases = (
       (),
       unit_sparse(),
       make_sparse([[3, 4]], [-1], [4, 5]),
       make_sparse([[0, 0], [3, 4]], [1, -1], [4, 5]),
       # NOTE(review): the parentheses carry no trailing comma, so this case
       # is a bare SparseTensor rather than a 1-tuple -- confirm the intent.
       (unit_sparse()),
       (unit_sparse(), ()),
       ((), unit_sparse()),
   )
   for expected in test_cases:
     serialized = sparse.serialize_sparse_tensors(expected)
     sparse_types = sparse.get_sparse_types(expected)
     actual = sparse.deserialize_sparse_tensors(serialized, sparse_types)
     nest.assert_same_structure(expected, actual)
     flat_pairs = zip(nest.flatten(actual), nest.flatten(expected))
     for got, want in flat_pairs:
       self.assertSparseValuesEqual(got, want)
示例#37
0
  def _next_internal(self):
    """Fetches the next element with the synchronous get-next op.

    Returns:
      The flat outputs of `iterator_get_next_sync`, packed into the structure
      of `self._output_types` and passed through
      `sparse.deserialize_sparse_tensors`.
    """
    # Sync mode is used because iterators rely on an error status to signal
    # that there is no more data to iterate over.
    # TODO(b/77291417): Fix
    with context.execution_mode(context.SYNC):
      # TODO(ashankar): Consider removing this ops.device() contextmanager
      # and instead mimic ops placement in graphs: Operations on resource
      # handles execute on the same device as where the resource is placed.
      with ops.device(self._device):
        # NOTE(mrry): The "_sync" variant of `iterator_get_next` is used
        # because in eager mode this code runs synchronously on the calling
        # thread. No defensive context switch to a background thread is
        # needed, and invoking the iterator synchronously gives a small
        # constant performance boost.
        flat_components = gen_dataset_ops.iterator_get_next_sync(
            self._resource,
            output_types=self._flat_output_types,
            output_shapes=self._flat_output_shapes)

      nested = nest.pack_sequence_as(self._output_types, flat_components)
      return sparse.deserialize_sparse_tensors(
          nested, self._output_types, self._output_shapes,
          self._output_classes)
示例#38
0
      def tf_scan_func(*args):
        """A wrapper for Defun that facilitates shape inference.

        Splits the flat `args` into state and input components, restores their
        static shapes and nested structure, calls `scan_func`, checks that the
        new state matches the initial state's classes and dtypes, and records
        the classes, shapes and types of the output value on `self`.

        Args:
          *args: Flat state components followed by the flat components of one
            input element.

        Returns:
          The flattened new state followed by the flattened output value, with
          any sparse tensors serialized.

        Raises:
          TypeError: If `scan_func` does not return a two-element sequence, or
            if the classes or dtypes of the new state differ from those of the
            initial state.
        """
        # Pass in shape information from the state and input_dataset.
        for arg, shape in zip(
            args,
            nest.flatten(
                sparse.as_dense_shapes(self._state_shapes, self._state_classes))
            + nest.flatten(
                sparse.as_dense_shapes(input_dataset.output_shapes,
                                       input_dataset.output_classes))):
          arg.set_shape(shape)

        # The first `pivot` arguments belong to the state, the rest to the
        # input element.
        pivot = len(nest.flatten(self._state_shapes))
        nested_state_args = nest.pack_sequence_as(self._state_types,
                                                  args[:pivot])
        nested_state_args = sparse.deserialize_sparse_tensors(
            nested_state_args, self._state_types, self._state_shapes,
            self._state_classes)
        nested_input_args = nest.pack_sequence_as(input_dataset.output_types,
                                                  args[pivot:])
        nested_input_args = sparse.deserialize_sparse_tensors(
            nested_input_args, input_dataset.output_types,
            input_dataset.output_shapes, input_dataset.output_classes)

        ret = scan_func(nested_state_args, nested_input_args)
        # The `collections.Sequence` alias was removed in Python 3.10; prefer
        # `collections.abc` when it is available and fall back to the legacy
        # location otherwise.
        sequence_type = getattr(collections, "abc", collections).Sequence
        if not isinstance(ret, sequence_type) or len(ret) != 2:
          raise TypeError("The scan function must return a pair comprising the "
                          "new state and the output value.")

        # Convert any `SparseTensorValue`s to `SparseTensor`s and all other
        # values to tensors.
        ret = nest.pack_sequence_as(ret, [
            sparse_tensor.SparseTensor.from_value(t)
            if sparse_tensor.is_sparse(t) else ops.convert_to_tensor(t)
            for t in nest.flatten(ret)
        ])
        new_state, output_value = ret

        # Extract and validate class information from the returned values.
        for t, clazz in zip(
            nest.flatten(new_state), nest.flatten(self._state_classes)):
          if not isinstance(t, clazz):
            raise TypeError(
                "The element classes for the new state must match the initial "
                "state. Expected %s; got %s." %
                (self._state_classes,
                 nest.pack_sequence_as(
                     self._state_types,
                     [type(t) for t in nest.flatten(new_state)])))
        self._output_classes = sparse.get_classes(output_value)

        # Extract shape information from the returned values.
        flat_new_state_shapes.extend(
            [t.get_shape() for t in nest.flatten(new_state)])
        self._output_shapes = nest.pack_sequence_as(
            output_value, [t.get_shape() for t in nest.flatten(output_value)])

        # Extract and validate type information from the returned values.
        for t, dtype in zip(
            nest.flatten(new_state), nest.flatten(self._state_types)):
          if t.dtype != dtype:
            raise TypeError(
                "The element types for the new state must match the initial "
                "state. Expected %s; got %s." %
                (self._state_types,
                 nest.pack_sequence_as(
                     self._state_types,
                     [t.dtype for t in nest.flatten(new_state)])))
        self._output_types = nest.pack_sequence_as(
            output_value, [t.dtype for t in nest.flatten(output_value)])

        dataset_ops._warn_if_collections("tf.contrib.data.scan()")  # pylint: disable=protected-access

        # Serialize any sparse tensors.
        new_state = nest.pack_sequence_as(
            new_state,
            list(nest.flatten(sparse.serialize_sparse_tensors(new_state))))
        output_value = nest.pack_sequence_as(
            output_value,
            list(nest.flatten(sparse.serialize_sparse_tensors(output_value))))
        return nest.flatten(new_state) + nest.flatten(output_value)
示例#39
0
            def tf_scan_func(*args):
                """A wrapper for Defun that facilitates shape inference.

                Splits the flat `args` into state and input components,
                restores their static shapes and nested structure, calls
                `scan_func`, checks that the new state matches the initial
                state's classes and dtypes, and records the classes, shapes
                and types of the output value on `self`.

                Args:
                  *args: Flat state components followed by the flat
                    components of one input element.

                Returns:
                  The flattened new state followed by the flattened output
                  value, with any sparse tensors serialized.

                Raises:
                  TypeError: If `scan_func` does not return a two-element
                    sequence, or if the classes or dtypes of the new state
                    differ from those of the initial state.
                """
                # Pass in shape information from the state and input_dataset.
                for arg, shape in zip(
                        args,
                        nest.flatten(
                            sparse.as_dense_shapes(self._state_shapes,
                                                   self._state_classes)) +
                        nest.flatten(
                            sparse.as_dense_shapes(
                                input_dataset.output_shapes,
                                input_dataset.output_classes))):
                    arg.set_shape(shape)

                # The first `pivot` arguments belong to the state, the rest
                # to the input element.
                pivot = len(nest.flatten(self._state_shapes))
                nested_state_args = nest.pack_sequence_as(
                    self._state_types, args[:pivot])
                nested_state_args = sparse.deserialize_sparse_tensors(
                    nested_state_args, self._state_types, self._state_shapes,
                    self._state_classes)
                nested_input_args = nest.pack_sequence_as(
                    input_dataset.output_types, args[pivot:])
                nested_input_args = sparse.deserialize_sparse_tensors(
                    nested_input_args, input_dataset.output_types,
                    input_dataset.output_shapes, input_dataset.output_classes)

                ret = scan_func(nested_state_args, nested_input_args)
                # The `collections.Sequence` alias was removed in Python
                # 3.10; prefer `collections.abc` when it is available and
                # fall back to the legacy location otherwise.
                sequence_type = getattr(collections, "abc",
                                        collections).Sequence
                if not isinstance(ret, sequence_type) or len(ret) != 2:
                    raise TypeError(
                        "The scan function must return a pair comprising the "
                        "new state and the output value.")

                # Convert any `SparseTensorValue`s to `SparseTensor`s and all other
                # values to tensors.
                ret = nest.pack_sequence_as(ret, [
                    sparse_tensor.SparseTensor.from_value(t)
                    if sparse_tensor.is_sparse(t) else ops.convert_to_tensor(t)
                    for t in nest.flatten(ret)
                ])
                new_state, output_value = ret

                # Extract and validate class information from the returned values.
                for t, clazz in zip(nest.flatten(new_state),
                                    nest.flatten(self._state_classes)):
                    if not isinstance(t, clazz):
                        raise TypeError(
                            "The element classes for the new state must match the initial "
                            "state. Expected %s; got %s." %
                            (self._state_classes,
                             nest.pack_sequence_as(
                                 self._state_types,
                                 [type(t) for t in nest.flatten(new_state)])))
                self._output_classes = sparse.get_classes(output_value)

                # Extract shape information from the returned values.
                flat_new_state_shapes.extend(
                    [t.get_shape() for t in nest.flatten(new_state)])
                self._output_shapes = nest.pack_sequence_as(
                    output_value,
                    [t.get_shape() for t in nest.flatten(output_value)])

                # Extract and validate type information from the returned values.
                for t, dtype in zip(nest.flatten(new_state),
                                    nest.flatten(self._state_types)):
                    if t.dtype != dtype:
                        raise TypeError(
                            "The element types for the new state must match the initial "
                            "state. Expected %s; got %s." %
                            (self._state_types,
                             nest.pack_sequence_as(
                                 self._state_types,
                                 [t.dtype for t in nest.flatten(new_state)])))
                self._output_types = nest.pack_sequence_as(
                    output_value,
                    [t.dtype for t in nest.flatten(output_value)])

                dataset_ops._warn_if_collections("tf.contrib.data.scan()")  # pylint: disable=protected-access

                # Serialize any sparse tensors.
                new_state = nest.pack_sequence_as(
                    new_state,
                    list(nest.flatten(
                        sparse.serialize_sparse_tensors(new_state))))
                output_value = nest.pack_sequence_as(
                    output_value,
                    list(nest.flatten(
                        sparse.serialize_sparse_tensors(output_value))))
                return nest.flatten(new_state) + nest.flatten(output_value)
示例#40
0
  def get_next(self, name=None):
    """Returns a nested structure of `tf.Tensor`s representing the next element.

    In graph mode this method is normally called *once*, and its result is
    fed into another computation. A typical loop then calls
    @{tf.Session.run} on the result of that computation and ends when the
    `Iterator.get_next()` operation raises @{tf.errors.OutOfRangeError}.
    The skeleton below shows the pattern in a training loop:

    ```python
    dataset = ...  # A `tf.data.Dataset` object.
    iterator = dataset.make_initializable_iterator()
    next_element = iterator.get_next()

    # Build a TensorFlow graph that does something with each element.
    loss = model_function(next_element)
    optimizer = ...  # A `tf.train.Optimizer` object.
    train_op = optimizer.minimize(loss)

    with tf.Session() as sess:
      try:
        while True:
          sess.run(train_op)
      except tf.errors.OutOfRangeError:
        pass
    ```

    NOTE: Calling `Iterator.get_next()` several times is legitimate, e.g.
    when different elements are distributed to multiple devices in a single
    step. A common pitfall, however, is calling `Iterator.get_next()` in
    every iteration of a training loop. Each call adds ops to the graph, and
    executing each op allocates resources (including threads), so doing this
    per iteration causes slowdown and eventual resource exhaustion. To guard
    against that, a warning is issued once the number of calls exceeds a
    fixed threshold.

    Args:
      name: (Optional.) A name for the created operation.

    Returns:
      A nested structure of `tf.Tensor` objects.
    """
    self._get_next_call_count += 1
    if self._get_next_call_count > GET_NEXT_CALL_WARNING_THRESHOLD:
      warnings.warn(GET_NEXT_CALL_WARNING_MESSAGE)

    # Flat dense views of the element types and shapes, as the op expects.
    flat_types = nest.flatten(
        sparse.as_dense_types(self._output_types, self._output_classes))
    flat_shapes = nest.flatten(
        sparse.as_dense_shapes(self._output_shapes, self._output_classes))
    flat_components = gen_dataset_ops.iterator_get_next(
        self._iterator_resource,
        output_types=flat_types,
        output_shapes=flat_shapes,
        name=name)

    nested = nest.pack_sequence_as(self._output_types, flat_components)
    return sparse.deserialize_sparse_tensors(
        nested, self._output_types, self._output_shapes, self._output_classes)