Example No. 1
0
def function_buffering_resource_get_next(function_buffer_resource,
                                         output_types,
                                         name=None):
    """Pulls the next buffered element out of a function-buffering resource.

    Thin wrapper that forwards its arguments to the generated dataset op of
    the same name.

    Args:
        function_buffer_resource: Handle to the buffering resource to read.
        output_types: Flat list of expected output dtypes.
        name: Optional name for the created op.

    Returns:
        The result of the underlying generated op (the next buffered element).
    """
    op_kwargs = {
        "function_buffer_resource": function_buffer_resource,
        "output_types": output_types,
        "name": name,
    }
    return gen_dataset_ops.function_buffering_resource_get_next(**op_kwargs)
Example No. 2
0
def function_buffering_resource_get_next(function_buffer_resource,
                                         output_types,
                                         name=None):
  """Retrieves the next element from `function_buffer_resource`.

  Delegates directly to the generated dataset op, passing every argument
  through by keyword.
  """
  resource = function_buffer_resource
  return gen_dataset_ops.function_buffering_resource_get_next(
      function_buffer_resource=resource, output_types=output_types, name=name)
    def get_next(self, name=None):
        """See `tf.data.Iterator.get_next`.

        Fetches one element from every per-device buffering resource,
        restores each element's nested structure and sparse tensors, and
        packs the per-device results according to `self._devices`.

        Args:
            name: Optional name for the created ops.

        Returns:
            A structure (packed like `self._devices`) of per-device elements.
        """
        self._get_next_call_count += 1
        if self._get_next_call_count > iterator_ops.GET_NEXT_CALL_WARNING_THRESHOLD:
            warnings.warn(iterator_ops.GET_NEXT_CALL_WARNING_MESSAGE)

        # The flattened dtype list is the same for every device, so compute
        # it once here instead of once per buffering resource.
        flat_output_types = data_nest.flatten(
            sparse.as_dense_types(self.output_types, self.output_classes))

        flat_result = []
        # TODO(priyag): This will fail if the input size (typically number of
        # batches) is not divisible by number of devices.
        # How do we handle that more gracefully / let the user know?
        for buffer_resource in self._buffering_resources:
            flat_ret = gen_dataset_ops.function_buffering_resource_get_next(
                buffer_resource, output_types=flat_output_types, name=name)

            # Restore the nested structure and re-materialize any sparse
            # tensors that were serialized for buffering.
            ret = sparse.deserialize_sparse_tensors(
                data_nest.pack_sequence_as(self.output_types, flat_ret),
                self.output_types, self.output_shapes, self.output_classes)

            # Static shape information is lost through the buffering op;
            # reattach it so downstream graph construction can use it.
            for tensor, shape in zip(data_nest.flatten(ret),
                                     data_nest.flatten(self.output_shapes)):
                if isinstance(tensor, ops.Tensor):
                    tensor.set_shape(shape)
            flat_result.append(ret)

        return nest.pack_sequence_as(self._devices, flat_result)
Example No. 4
0
  def get_next(self, name=None):
    """See `tf.data.Iterator.get_next`."""
    self._get_next_call_count += 1
    if self._get_next_call_count > iterator_ops.GET_NEXT_CALL_WARNING_THRESHOLD:
      warnings.warn(iterator_ops.GET_NEXT_CALL_WARNING_MESSAGE)

    per_device_results = []
    # TODO(priyag): This will fail if the input size (typically number of
    # batches) is not divisible by number of devices.
    # How do we handle that more gracefully / let the user know?
    for resource in self._buffering_resources:
      # Fetch the flat list of dense tensors buffered for this device.
      dense_types = sparse.as_dense_types(self.output_types,
                                          self.output_classes)
      flat_values = gen_dataset_ops.function_buffering_resource_get_next(
          resource, output_types=data_nest.flatten(dense_types), name=name)

      # Rebuild the nested structure, then restore sparse tensors.
      structured = data_nest.pack_sequence_as(self.output_types, flat_values)
      element = sparse.deserialize_sparse_tensors(
          structured, self.output_types, self.output_shapes,
          self.output_classes)

      # Reattach the statically-known shapes to each dense tensor.
      flat_element = data_nest.flatten(element)
      flat_shapes = data_nest.flatten(self.output_shapes)
      for tensor, shape in zip(flat_element, flat_shapes):
        if isinstance(tensor, ops.Tensor):
          tensor.set_shape(shape)
      per_device_results.append(element)

    return nest.pack_sequence_as(self._devices, per_device_results)
Example No. 5
0
 def _next_internal(self):
   """Returns a nested structure of `tf.Tensor`s containing the next element.
   """
   # Iterators report end-of-sequence through an error status, so this must
   # execute synchronously to observe it.
   # TODO(b/77291417): Fix
   with context.execution_mode(context.SYNC):
     # Only the op itself needs to run on the iterator's device.
     with ops.device(self._device):
       flat_values = gen_dataset_ops.function_buffering_resource_get_next(
           function_buffer_resource=self._buffering_resource,
           output_types=self._flat_output_types)
     structured = nest.pack_sequence_as(self._output_types, flat_values)
     return sparse.deserialize_sparse_tensors(
         structured, self._output_types, self._output_shapes,
         self._output_classes)
Example No. 6
0
  def get_next(self, name=None):
    """See `tf.data.Iterator.get_next`."""
    self._get_next_call_count += 1
    if self._get_next_call_count > iterator_ops.GET_NEXT_CALL_WARNING_THRESHOLD:
      warnings.warn(iterator_ops.GET_NEXT_CALL_WARNING_MESSAGE)

    # Fetch the next buffered element as a flat list of dense tensors.
    dense_types = sparse.as_dense_types(self.output_types, self.output_classes)
    flat_values = gen_dataset_ops.function_buffering_resource_get_next(
        self._buffering_resource,
        output_types=nest.flatten(dense_types),
        name=name)

    # Rebuild the nested structure and restore any sparse tensors.
    element = sparse.deserialize_sparse_tensors(
        nest.pack_sequence_as(self.output_types, flat_values),
        self.output_types, self.output_shapes, self.output_classes)

    # Reattach static shape information, which the buffering op drops.
    flat_shapes = nest.flatten(self.output_shapes)
    for tensor, shape in zip(nest.flatten(element), flat_shapes):
      if isinstance(tensor, ops.Tensor):
        tensor.set_shape(shape)

    return element