def function_buffering_resource_get_next(function_buffer_resource,
                                         output_types,
                                         name=None):
  """Fetches the next element from a function buffering resource.

  Thin wrapper that forwards its arguments, by keyword, to the generated
  `gen_dataset_ops.function_buffering_resource_get_next` op.

  Args:
    function_buffer_resource: The buffering resource to read from.
    output_types: The component types of the element to fetch (callers
      elsewhere pass a flattened list of types).
    name: Optional name for the created op.

  Returns:
    Whatever the generated op returns for the fetched element.
  """
  op_kwargs = dict(
      function_buffer_resource=function_buffer_resource,
      output_types=output_types,
      name=name,
  )
  return gen_dataset_ops.function_buffering_resource_get_next(**op_kwargs)
def get_next(self, name=None):
  """Fetches one element from each of the iterator's buffering resources.

  See `tf.data.Iterator.get_next`.

  Args:
    name: Optional name forwarded to every per-resource get-next op.

  Returns:
    The per-resource elements (in `self._buffering_resources` order), packed
    into the structure of `self._devices`.
  """
  self._get_next_call_count += 1
  if self._get_next_call_count > iterator_ops.GET_NEXT_CALL_WARNING_THRESHOLD:
    warnings.warn(iterator_ops.GET_NEXT_CALL_WARNING_MESSAGE)

  def _fetch_one(buffer_resource):
    """Reads, deserializes and shape-annotates one element from a resource."""
    flat_dense_types = data_nest.flatten(
        sparse.as_dense_types(self.output_types, self.output_classes))
    flat_values = gen_dataset_ops.function_buffering_resource_get_next(
        buffer_resource, output_types=flat_dense_types, name=name)
    value = sparse.deserialize_sparse_tensors(
        data_nest.pack_sequence_as(self.output_types, flat_values),
        self.output_types, self.output_shapes, self.output_classes)
    # Attach the iterator's static shapes to the dense `ops.Tensor` parts.
    pairs = zip(data_nest.flatten(value), data_nest.flatten(self.output_shapes))
    for tensor, shape in pairs:
      if isinstance(tensor, ops.Tensor):
        tensor.set_shape(shape)
    return value

  # TODO(priyag): This will fail if the input size (typically number of
  # batches) is not divisible by number of devices.
  # How do we handle that more gracefully / let the user know?
  per_resource = [_fetch_one(r) for r in self._buffering_resources]
  return nest.pack_sequence_as(self._devices, per_resource)
def get_next(self, name=None):
  """Fetches one element from every buffering resource of this iterator.

  See `tf.data.Iterator.get_next`.

  Args:
    name: Optional name forwarded to each per-resource get-next op.

  Returns:
    A structure matching `self._devices` whose leaves are the elements read
    from `self._buffering_resources`, in iteration order.
  """
  self._get_next_call_count += 1
  threshold = iterator_ops.GET_NEXT_CALL_WARNING_THRESHOLD
  if self._get_next_call_count > threshold:
    warnings.warn(iterator_ops.GET_NEXT_CALL_WARNING_MESSAGE)

  # TODO(priyag): This will fail if the input size (typically number of
  # batches) is not divisible by number of devices.
  # How do we handle that more gracefully / let the user know?
  gathered = []
  for resource_handle in self._buffering_resources:
    dense_types = sparse.as_dense_types(self.output_types, self.output_classes)
    raw_components = gen_dataset_ops.function_buffering_resource_get_next(
        resource_handle,
        output_types=data_nest.flatten(dense_types),
        name=name)
    structured = data_nest.pack_sequence_as(self.output_types, raw_components)
    element = sparse.deserialize_sparse_tensors(
        structured, self.output_types, self.output_shapes, self.output_classes)
    # Only dense `ops.Tensor` components receive the static shape.
    component_shape_pairs = zip(data_nest.flatten(element),
                                data_nest.flatten(self.output_shapes))
    for component, static_shape in component_shape_pairs:
      if not isinstance(component, ops.Tensor):
        continue
      component.set_shape(static_shape)
    gathered.append(element)
  return nest.pack_sequence_as(self._devices, gathered)
def _next_internal(self): """Returns a nested structure of `tf.Tensor`s containing the next element. """ # This runs in sync mode as iterators use an error status to communicate # that there is no more data to iterate over. # TODO(b/77291417): Fix with context.execution_mode(context.SYNC): with ops.device(self._device): ret = gen_dataset_ops.function_buffering_resource_get_next( function_buffer_resource=self._buffering_resource, output_types=self._flat_output_types) return sparse.deserialize_sparse_tensors( nest.pack_sequence_as(self._output_types, ret), self._output_types, self._output_shapes, self._output_classes)
def get_next(self, name=None):
  """Fetches the next element from this iterator's buffering resource.

  See `tf.data.Iterator.get_next`.

  Args:
    name: Optional name for the underlying get-next op.

  Returns:
    The next element, structured like `self.output_types`, with sparse
    components deserialized and static shapes set on dense tensors.
  """
  self._get_next_call_count += 1
  if self._get_next_call_count > iterator_ops.GET_NEXT_CALL_WARNING_THRESHOLD:
    warnings.warn(iterator_ops.GET_NEXT_CALL_WARNING_MESSAGE)

  dense_types = sparse.as_dense_types(self.output_types, self.output_classes)
  components = gen_dataset_ops.function_buffering_resource_get_next(
      self._buffering_resource,
      output_types=nest.flatten(dense_types),
      name=name)
  element = sparse.deserialize_sparse_tensors(
      nest.pack_sequence_as(self.output_types, components),
      self.output_types, self.output_shapes, self.output_classes)

  # Attach the statically known shapes to the dense `ops.Tensor` components.
  for component, static_shape in zip(nest.flatten(element),
                                     nest.flatten(self.output_shapes)):
    if isinstance(component, ops.Tensor):
      component.set_shape(static_shape)
  return element