def tf_reduce_func(*args):
    """Defun wrapper that calls `reduce_func` on repacked, shaped arguments.

    The flat `args` carry the state components first, followed by the
    components of one input element. Each argument receives its dense shape
    via `set_shape`; both groups are then packed into their nested
    structures and passed through `sparse.deserialize_sparse_tensors`
    before `reduce_func` is invoked.

    The shapes of the new state are appended to `flat_new_state_shapes`.

    Returns:
      The flattened new state after `sparse.serialize_sparse_tensors`.

    Raises:
      TypeError: If a component dtype of the new state differs from the
        matching entry of `self._state_types`.
    """
    state_dense_shapes = nest.flatten(
        sparse.as_dense_shapes(self._state_shapes, self._state_classes))
    input_dense_shapes = nest.flatten(
        sparse.as_dense_shapes(input_dataset.output_shapes,
                               input_dataset.output_classes))
    for flat_arg, dense_shape in zip(
            args, state_dense_shapes + input_dense_shapes):
        flat_arg.set_shape(dense_shape)

    # The leading `num_state` arguments are state; the rest are the element.
    num_state = len(nest.flatten(self._state_shapes))

    state = sparse.deserialize_sparse_tensors(
        nest.pack_sequence_as(self._state_types, args[:num_state]),
        self._state_types, self._state_shapes, self._state_classes)
    element = sparse.deserialize_sparse_tensors(
        nest.pack_sequence_as(input_dataset.output_types, args[num_state:]),
        input_dataset.output_types, input_dataset.output_shapes,
        input_dataset.output_classes)

    new_state = reduce_func(state, element)

    def _to_tensor(value):
        # `SparseTensorValue`s become `SparseTensor`s; everything else is
        # converted to a tensor.
        if sparse_tensor.is_sparse(value):
            return sparse_tensor.SparseTensor.from_value(value)
        return ops.convert_to_tensor(value)

    new_state = nest.pack_sequence_as(
        new_state, [_to_tensor(v) for v in nest.flatten(new_state)])

    # Record the shapes of the returned state.
    flat_new_state = nest.flatten(new_state)
    flat_new_state_shapes.extend(
        [component.get_shape() for component in flat_new_state])

    # Validate the dtypes of the returned state against the initial state.
    expected_dtypes = nest.flatten(self._state_types)
    if any(component.dtype != expected
           for component, expected in zip(flat_new_state, expected_dtypes)):
        raise TypeError(
            "The element types for the new state must match the initial "
            "state. Expected %s; got %s." %
            (self._state_types,
             nest.pack_sequence_as(
                 self._state_types,
                 [component.dtype for component in flat_new_state])))

    # Serialize any sparse tensors.
    serialized = sparse.serialize_sparse_tensors(new_state)
    new_state = nest.pack_sequence_as(
        new_state, list(nest.flatten(serialized)))
    return nest.flatten(new_state)
def tf_reduce_func(*args):
    """Defun wrapper that calls `reduce_func` on repacked, shaped arguments.

    `args` is flat: state components come first, then the components of one
    input element. Dense shapes are applied with `set_shape`, the two
    groups are packed back into nested structures and run through
    `sparse.deserialize_sparse_tensors`, and `reduce_func` is called on the
    pair.

    The shapes of the new state are appended to `flat_new_state_shapes`,
    and `dataset_ops._warn_if_collections` is invoked for
    "tf.contrib.data.group_by_reducer()".

    Returns:
      The flattened new state after `sparse.serialize_sparse_tensors`.

    Raises:
      TypeError: If a component dtype of the new state differs from the
        matching entry of `self._state_types`.
    """
    all_dense_shapes = (
        nest.flatten(
            sparse.as_dense_shapes(self._state_shapes, self._state_classes))
        + nest.flatten(
            sparse.as_dense_shapes(input_dataset.output_shapes,
                                   input_dataset.output_classes)))
    for tensor_arg, known_shape in zip(args, all_dense_shapes):
        tensor_arg.set_shape(known_shape)

    # Split point between state arguments and input-element arguments.
    split_at = len(nest.flatten(self._state_shapes))

    packed_state = nest.pack_sequence_as(self._state_types, args[:split_at])
    packed_state = sparse.deserialize_sparse_tensors(
        packed_state, self._state_types, self._state_shapes,
        self._state_classes)
    packed_input = nest.pack_sequence_as(
        input_dataset.output_types, args[split_at:])
    packed_input = sparse.deserialize_sparse_tensors(
        packed_input, input_dataset.output_types,
        input_dataset.output_shapes, input_dataset.output_classes)

    result = reduce_func(packed_state, packed_input)

    # Convert any `SparseTensorValue`s to `SparseTensor`s and all other
    # values to tensors.
    converted = []
    for value in nest.flatten(result):
        if sparse_tensor.is_sparse(value):
            converted.append(sparse_tensor.SparseTensor.from_value(value))
        else:
            converted.append(ops.convert_to_tensor(value))
    result = nest.pack_sequence_as(result, converted)

    # Extract shape information from the returned values.
    flat_new_state = nest.flatten(result)
    flat_new_state_shapes.extend([c.get_shape() for c in flat_new_state])

    # Extract and validate type information from the returned values.
    for component, expected in zip(flat_new_state,
                                   nest.flatten(self._state_types)):
        if component.dtype != expected:
            raise TypeError(
                "The element types for the new state must match the initial "
                "state. Expected %s; got %s." %
                (self._state_types,
                 nest.pack_sequence_as(
                     self._state_types,
                     [c.dtype for c in flat_new_state])))

    dataset_ops._warn_if_collections("tf.contrib.data.group_by_reducer()")  # pylint: disable=protected-access

    # Serialize any sparse tensors.
    flat_serialized = list(
        nest.flatten(sparse.serialize_sparse_tensors(result)))
    result = nest.pack_sequence_as(result, flat_serialized)
    return nest.flatten(result)
def _next_internal(self):
    """Fetches the next element from the iterator resource.

    When executing eagerly, the element is read with
    `iterator_get_next_sync` under `context.SYNC` execution mode and the
    packed result goes through `sparse.deserialize_sparse_tensors`.
    Otherwise it is read with `iterator_get_next` and rebuilt through
    `self._structure`.

    Returns:
      The next element as a nested structure.
    """
    if context.executing_eagerly():
        # This runs in sync mode as iterators use an error status to
        # communicate that there is no more data to iterate over.
        # TODO(b/77291417): Fix
        with context.execution_mode(context.SYNC):
            with ops.device(self._device):
                # TODO(ashankar): Consider removing this ops.device()
                # contextmanager and instead mimic ops placement in graphs:
                # Operations on resource handles execute on the same device
                # as where the resource is placed.
                # NOTE(mrry): Here we use the "_sync" variant of
                # `iterator_get_next` because in eager mode this code will
                # run synchronously on the calling thread. Therefore we do
                # not need to make a defensive context switch to a
                # background thread, and can achieve a small constant
                # performance boost by invoking the iterator synchronously.
                flat_element = gen_dataset_ops.iterator_get_next_sync(
                    self._iterator_resource,
                    output_types=self._flat_output_types,
                    output_shapes=self._flat_output_shapes)

            packed = nest.pack_sequence_as(self._output_types, flat_element)
            return sparse.deserialize_sparse_tensors(
                packed, self._output_types, self._output_shapes,
                self._output_classes)

    with ops.device(self._device):
        flat_element = gen_dataset_ops.iterator_get_next(
            self._iterator_resource,
            output_types=self._flat_output_types,
            output_shapes=self._flat_output_shapes)
    return self._structure._from_compatible_tensor_list(flat_element)  # pylint: disable=protected-access
def tf_finalize_func(*args):
    """Defun wrapper that calls `finalize_func` on the repacked state.

    The flat state `args` receive their dense shapes via `set_shape`, are
    packed into the state structure and passed through
    `sparse.deserialize_sparse_tensors` before `finalize_func` is invoked.

    The classes, shapes and dtypes of the result are stored on
    `self._output_classes`, `self._output_shapes` and `self._output_types`.

    Returns:
      The flattened result after `sparse.serialize_sparse_tensors`.
    """
    state_dense_shapes = nest.flatten(
        sparse.as_dense_shapes(self._state_shapes, self._state_classes))
    for flat_arg, dense_shape in zip(args, state_dense_shapes):
        flat_arg.set_shape(dense_shape)

    state = sparse.deserialize_sparse_tensors(
        nest.pack_sequence_as(self._state_types, args),
        self._state_types, self._state_shapes, self._state_classes)

    result = finalize_func(state)

    def _to_tensor(value):
        # `SparseTensorValue`s become `SparseTensor`s; everything else is
        # converted to a tensor.
        if sparse_tensor.is_sparse(value):
            return sparse_tensor.SparseTensor.from_value(value)
        return ops.convert_to_tensor(value)

    result = nest.pack_sequence_as(
        result, [_to_tensor(v) for v in nest.flatten(result)])

    # Record the output signature derived from the finalized value.
    self._output_classes = sparse.get_classes(result)
    self._output_shapes = nest.pack_sequence_as(
        result, [c.get_shape() for c in nest.flatten(result)])
    self._output_types = nest.pack_sequence_as(
        result, [c.dtype for c in nest.flatten(result)])

    # Serialize any sparse tensors.
    serialized = sparse.serialize_sparse_tensors(result)
    result = nest.pack_sequence_as(result, list(nest.flatten(serialized)))
    return nest.flatten(result)
def get_next(self, name=None):
    """Returns a nested structure of `tf.Tensor`s containing the next element.

    Every call increments `self._get_next_call_count`; once the count
    exceeds `GET_NEXT_CALL_WARNING_THRESHOLD`,
    `GET_NEXT_CALL_WARNING_MESSAGE` is emitted through `warnings.warn`.

    Args:
      name: (Optional.) A name for the created operation.

    Returns:
      A nested structure of `tf.Tensor` objects.
    """
    self._get_next_call_count += 1
    if self._get_next_call_count > GET_NEXT_CALL_WARNING_THRESHOLD:
        warnings.warn(GET_NEXT_CALL_WARNING_MESSAGE)

    dense_types = nest.flatten(
        sparse.as_dense_types(self._output_types, self._output_classes))
    dense_shapes = nest.flatten(
        sparse.as_dense_shapes(self._output_shapes, self._output_classes))
    flat_element = gen_dataset_ops.iterator_get_next(
        self._iterator_resource,
        output_types=dense_types,
        output_shapes=dense_shapes,
        name=name)

    packed = nest.pack_sequence_as(self._output_types, flat_element)
    return sparse.deserialize_sparse_tensors(
        packed, self._output_types, self._output_shapes,
        self._output_classes)
def testSerializeDeserialize(self):
    """Round-trips nested structures through serialize and deserialize.

    Each case is passed through `sparse.serialize_sparse_tensors` and then
    `sparse.deserialize_sparse_tensors`. The result must have the same
    nested structure as the input and equal sparse values.
    """
    test_cases = (
        (),
        sparse_tensor.SparseTensor(
            indices=[[0, 0]], values=[1], dense_shape=[1, 1]),
        sparse_tensor.SparseTensor(
            indices=[[3, 4]], values=[-1], dense_shape=[4, 5]),
        sparse_tensor.SparseTensor(
            indices=[[0, 0], [3, 4]], values=[1, -1], dense_shape=[4, 5]),
        # The trailing comma makes this a real one-element tuple. Without
        # it the parentheses are only grouping and the case would be the
        # same bare `SparseTensor` as the second entry.
        (sparse_tensor.SparseTensor(
            indices=[[0, 0]], values=[1], dense_shape=[1, 1]),),
        (sparse_tensor.SparseTensor(
            indices=[[0, 0]], values=[1], dense_shape=[1, 1]), ()),
        ((), sparse_tensor.SparseTensor(
            indices=[[0, 0]], values=[1], dense_shape=[1, 1])),
    )
    for expected in test_cases:
        classes = sparse.get_classes(expected)
        shapes = nest.map_structure(
            lambda _: tensor_shape.TensorShape(None), classes)
        types = nest.map_structure(lambda _: dtypes.int32, classes)
        # `classes` is reused instead of being computed a second time.
        actual = sparse.deserialize_sparse_tensors(
            sparse.serialize_sparse_tensors(expected), types, shapes,
            classes)
        nest.assert_same_structure(expected, actual)
        for a, e in zip(nest.flatten(actual), nest.flatten(expected)):
            self.assertSparseValuesEqual(a, e)
def _next_internal(self):
    """Returns a nested structure of `tf.Tensor`s containing the next element.

    The flat element is read on `self._device`: from the function buffering
    resource when `self._buffer_resource_handle` is set, otherwise from
    `self._resource` with `iterator_get_next_sync`. The flat result is then
    packed and passed through `sparse.deserialize_sparse_tensors`.
    """
    with ops.device(self._device):
        if self._buffer_resource_handle is None:
            # TODO(ashankar): Consider removing this ops.device()
            # contextmanager and instead mimic ops placement in graphs:
            # Operations on resource handles execute on the same device as
            # where the resource is placed.
            # NOTE(mrry): Here we use the "_sync" variant of
            # `iterator_get_next` because in eager mode this code will run
            # synchronously on the calling thread. Therefore we do not need
            # to make a defensive context switch to a background thread,
            # and can achieve a small constant performance boost by
            # invoking the iterator synchronously.
            flat_element = gen_dataset_ops.iterator_get_next_sync(
                self._resource,
                output_types=self._flat_output_types,
                output_shapes=self._flat_output_shapes)
        else:
            flat_element = (
                prefetching_ops.function_buffering_resource_get_next(
                    function_buffer_resource=self._buffer_resource_handle,
                    output_types=self._flat_output_types))

    packed = nest.pack_sequence_as(self._output_types, flat_element)
    return sparse.deserialize_sparse_tensors(
        packed, self._output_types, self._output_shapes,
        self._output_classes)
def _next_internal(self):
    """Returns a nested structure of `tf.Tensor`s containing the next element.

    Reads the flat element on `self._device`, preferring the function
    buffering resource when `self._buffer_resource_handle` is not None and
    falling back to `iterator_get_next_sync` on `self._resource`. The flat
    tensors are then packed as `self._output_types` and passed through
    `sparse.deserialize_sparse_tensors`.
    """
    use_buffer = self._buffer_resource_handle is not None
    with ops.device(self._device):
        if use_buffer:
            flat_tensors = (
                prefetching_ops.function_buffering_resource_get_next(
                    function_buffer_resource=self._buffer_resource_handle,
                    output_types=self._flat_output_types))
        else:
            # TODO(ashankar): Consider removing this ops.device()
            # contextmanager and instead mimic ops placement in graphs:
            # Operations on resource handles execute on the same device as
            # where the resource is placed.
            # NOTE(mrry): Here we use the "_sync" variant of
            # `iterator_get_next` because in eager mode this code will run
            # synchronously on the calling thread. Therefore we do not need
            # to make a defensive context switch to a background thread,
            # and can achieve a small constant performance boost by
            # invoking the iterator synchronously.
            flat_tensors = gen_dataset_ops.iterator_get_next_sync(
                self._resource,
                output_types=self._flat_output_types,
                output_shapes=self._flat_output_shapes)

    nested = nest.pack_sequence_as(self._output_types, flat_tensors)
    return sparse.deserialize_sparse_tensors(
        nested, self._output_types, self._output_shapes,
        self._output_classes)
def testSerializeManyDeserialize(self):
    """Round-trips nested structures through serialize-many and deserialize.

    Each case is passed through `sparse.serialize_many_sparse_tensors` and
    then `sparse.deserialize_sparse_tensors`. The result must have the same
    nested structure as the input and equal sparse values.
    """
    test_cases = (
        (),
        sparse_tensor.SparseTensor(
            indices=[[0, 0]], values=[1], dense_shape=[1, 1]),
        sparse_tensor.SparseTensor(
            indices=[[3, 4]], values=[-1], dense_shape=[4, 5]),
        sparse_tensor.SparseTensor(
            indices=[[0, 0], [3, 4]], values=[1, -1], dense_shape=[4, 5]),
        # The trailing comma makes this a real one-element tuple. Without
        # it the parentheses are only grouping and the case would be the
        # same bare `SparseTensor` as the second entry.
        (sparse_tensor.SparseTensor(
            indices=[[0, 0]], values=[1], dense_shape=[1, 1]),),
        (sparse_tensor.SparseTensor(
            indices=[[0, 0]], values=[1], dense_shape=[1, 1]), ()),
        ((), sparse_tensor.SparseTensor(
            indices=[[0, 0]], values=[1], dense_shape=[1, 1])),
    )
    for expected in test_cases:
        classes = sparse.get_classes(expected)
        shapes = nest.map_structure(
            lambda _: tensor_shape.TensorShape(None), classes)
        types = nest.map_structure(lambda _: dtypes.int32, classes)
        # `classes` is reused instead of being computed a second time.
        actual = sparse.deserialize_sparse_tensors(
            sparse.serialize_many_sparse_tensors(expected), types, shapes,
            classes)
        nest.assert_same_structure(expected, actual)
        for a, e in zip(nest.flatten(actual), nest.flatten(expected)):
            self.assertSparseValuesEqual(a, e)
def get_next(self, name=None):
    """See `tf.data.Iterator.get_next`.

    One element is fetched from every entry of `self._buffering_resources`;
    the per-resource results are packed using the structure of
    `self._devices`. The call counter is incremented, and a warning is
    emitted once it exceeds `iterator_ops.GET_NEXT_CALL_WARNING_THRESHOLD`.
    """
    self._get_next_call_count += 1
    if (self._get_next_call_count >
            iterator_ops.GET_NEXT_CALL_WARNING_THRESHOLD):
        warnings.warn(iterator_ops.GET_NEXT_CALL_WARNING_MESSAGE)

    per_device_elements = []
    # TODO(priyag): This will fail if the input size (typically number of
    # batches) is not divisible by number of devices.
    # How do we handle that more gracefully / let the user know?
    for resource in self._buffering_resources:
        dense_types = data_nest.flatten(
            sparse.as_dense_types(self.output_types, self.output_classes))
        flat_element = (
            ged_ops.experimental_function_buffering_resource_get_next(
                resource, output_types=dense_types, name=name))

        element = sparse.deserialize_sparse_tensors(
            data_nest.pack_sequence_as(self.output_types, flat_element),
            self.output_types, self.output_shapes, self.output_classes)

        # Only dense tensors get the known output shape applied.
        component_shapes = zip(data_nest.flatten(element),
                               data_nest.flatten(self.output_shapes))
        for component, known_shape in component_shapes:
            if isinstance(component, ops.Tensor):
                component.set_shape(known_shape)
        per_device_elements.append(element)

    return nest.pack_sequence_as(self._devices, per_device_elements)
def tf_key_func(*args):
    """Defun wrapper that calls `key_func` on a repacked input element.

    The flat `args` receive the dense shapes of `input_dataset` via
    `set_shape`, are packed into the dataset's nested structure and passed
    through `sparse.deserialize_sparse_tensors` before `key_func` is
    invoked.

    Returns:
      The result of `key_func`, converted to a tensor.

    Raises:
      ValueError: If the converted result fails the dtype check against
        `dtypes.int64` or the shape check against `tensor_shape.scalar()`.
    """
    # Pass in shape information from the input_dataset.
    flat_dense_shapes = nest.flatten(
        sparse.as_dense_shapes(input_dataset.output_shapes,
                               input_dataset.output_classes))
    for flat_arg, dense_shape in zip(args, flat_dense_shapes):
        flat_arg.set_shape(dense_shape)

    element = sparse.deserialize_sparse_tensors(
        nest.pack_sequence_as(input_dataset.output_types, args),
        input_dataset.output_types, input_dataset.output_shapes,
        input_dataset.output_classes)

    unpack = dataset_ops._should_unpack_args(element)  # pylint: disable=protected-access
    raw_key = key_func(*element) if unpack else key_func(element)

    key = ops.convert_to_tensor(raw_key)
    if key.dtype != dtypes.int64 or key.get_shape() != tensor_shape.scalar():
        raise ValueError(
            "`key_func` must return a single tf.int64 tensor. "
            "Got type=%s and shape=%s" % (key.dtype, key.get_shape()))
    return key
def get_next(self, name=None):
    """See `tf.data.Iterator.get_next`.

    Fetches one element from each resource in `self._buffering_resources`
    and packs the results using the structure of `self._devices`. The call
    counter is bumped first; a warning is emitted once it exceeds
    `iterator_ops.GET_NEXT_CALL_WARNING_THRESHOLD`.
    """
    self._get_next_call_count += 1
    if (self._get_next_call_count >
            iterator_ops.GET_NEXT_CALL_WARNING_THRESHOLD):
        warnings.warn(iterator_ops.GET_NEXT_CALL_WARNING_MESSAGE)

    results = []
    # TODO(priyag): This will fail if the input size (typically number of
    # batches) is not divisible by number of devices.
    # How do we handle that more gracefully / let the user know?
    for resource in self._buffering_resources:
        flat_types = data_nest.flatten(
            sparse.as_dense_types(self.output_types, self.output_classes))
        flat_tensors = gen_dataset_ops.function_buffering_resource_get_next(
            resource, output_types=flat_types, name=name)

        nested = data_nest.pack_sequence_as(self.output_types, flat_tensors)
        element = sparse.deserialize_sparse_tensors(
            nested, self.output_types, self.output_shapes,
            self.output_classes)

        # Only dense tensors get the known output shape applied.
        flat_components = data_nest.flatten(element)
        flat_shapes = data_nest.flatten(self.output_shapes)
        for component, known_shape in zip(flat_components, flat_shapes):
            if isinstance(component, ops.Tensor):
                component.set_shape(known_shape)
        results.append(element)

    return nest.pack_sequence_as(self._devices, results)
def tf_finalize_func(*args):
    """Defun wrapper that calls `finalize_func` on the repacked state.

    Dense state shapes are applied to `args` with `set_shape`; the
    arguments are packed as `self._state_types` and passed through
    `sparse.deserialize_sparse_tensors` before `finalize_func` is invoked.

    The classes, shapes and dtypes of the result are stored on
    `self._output_classes`, `self._output_shapes` and `self._output_types`.

    Returns:
      The flattened result after `sparse.serialize_sparse_tensors`.
    """
    known_shapes = nest.flatten(
        sparse.as_dense_shapes(self._state_shapes, self._state_classes))
    for tensor_arg, known_shape in zip(args, known_shapes):
        tensor_arg.set_shape(known_shape)

    packed_state = nest.pack_sequence_as(self._state_types, args)
    packed_state = sparse.deserialize_sparse_tensors(
        packed_state, self._state_types, self._state_shapes,
        self._state_classes)

    output = finalize_func(packed_state)

    # Convert any `SparseTensorValue`s to `SparseTensor`s and all other
    # values to tensors.
    converted = []
    for value in nest.flatten(output):
        if sparse_tensor.is_sparse(value):
            converted.append(sparse_tensor.SparseTensor.from_value(value))
        else:
            converted.append(ops.convert_to_tensor(value))
    output = nest.pack_sequence_as(output, converted)

    self._output_classes = sparse.get_classes(output)
    self._output_shapes = nest.pack_sequence_as(
        output, [c.get_shape() for c in nest.flatten(output)])
    self._output_types = nest.pack_sequence_as(
        output, [c.dtype for c in nest.flatten(output)])

    # Serialize any sparse tensors.
    flat_serialized = list(
        nest.flatten(sparse.serialize_sparse_tensors(output)))
    output = nest.pack_sequence_as(output, flat_serialized)
    return nest.flatten(output)
def tf_key_func(*args):
    """Defun wrapper that calls `key_func` on a repacked input element.

    The flat `args` receive the dense shapes of `input_dataset` via
    `set_shape`, are packed into the dataset's nested structure and passed
    through `sparse.deserialize_sparse_tensors` before `key_func` is
    invoked. After validation, `dataset_ops._warn_if_collections` is called
    for "tf.contrib.data.group_by_reducer()".

    Returns:
      The result of `key_func`, converted to a tensor.

    Raises:
      ValueError: If the converted result fails the dtype check against
        `dtypes.int64` or the shape check against `tensor_shape.scalar()`.
    """
    # Pass in shape information from the input_dataset.
    known_shapes = nest.flatten(
        sparse.as_dense_shapes(input_dataset.output_shapes,
                               input_dataset.output_classes))
    for tensor_arg, known_shape in zip(args, known_shapes):
        tensor_arg.set_shape(known_shape)

    packed = nest.pack_sequence_as(input_dataset.output_types, args)
    packed = sparse.deserialize_sparse_tensors(
        packed, input_dataset.output_types, input_dataset.output_shapes,
        input_dataset.output_classes)

    should_unpack = dataset_ops._should_unpack_args(packed)  # pylint: disable=protected-access
    key = ops.convert_to_tensor(
        key_func(*packed) if should_unpack else key_func(packed))

    if key.dtype != dtypes.int64 or key.get_shape() != tensor_shape.scalar():
        raise ValueError(
            "`key_func` must return a single tf.int64 tensor. "
            "Got type=%s and shape=%s" % (key.dtype, key.get_shape()))

    dataset_ops._warn_if_collections("tf.contrib.data.group_by_reducer()")  # pylint: disable=protected-access
    return key
def tf_map_func(*args):
    """Defun wrapper that calls `map_func` and returns a dataset variant.

    The flat `args` receive the dense shapes of `input_dataset` via
    `set_shape`, are packed into the dataset's nested structure and passed
    through `sparse.deserialize_sparse_tensors` before `map_func` is
    invoked.

    The output classes, types and shapes of the returned dataset are copied
    onto `self`.

    Returns:
      The value of `_as_variant_tensor()` on the dataset that `map_func`
      returned.

    Raises:
      TypeError: If `map_func` does not return a `dataset_ops.Dataset`.
    """
    # Pass in shape information from the input_dataset.
    flat_dense_shapes = nest.flatten(
        sparse.as_dense_shapes(input_dataset.output_shapes,
                               input_dataset.output_classes))
    for flat_arg, dense_shape in zip(args, flat_dense_shapes):
        flat_arg.set_shape(dense_shape)

    element = sparse.deserialize_sparse_tensors(
        nest.pack_sequence_as(input_dataset.output_types, args),
        input_dataset.output_types, input_dataset.output_shapes,
        input_dataset.output_classes)

    unpack = dataset_ops._should_unpack_args(element)  # pylint: disable=protected-access
    mapped = map_func(*element) if unpack else map_func(element)

    if not isinstance(mapped, dataset_ops.Dataset):
        raise TypeError("`map_func` must return a `Dataset` object.")

    self._output_classes = mapped.output_classes
    self._output_types = mapped.output_types
    self._output_shapes = mapped.output_shapes

    return mapped._as_variant_tensor()  # pylint: disable=protected-access
def testSerializeDeserialize(self):
    """Round-trips nested structures through serialize and deserialize.

    Each case is passed through `sparse.serialize_sparse_tensors` and then
    `sparse.deserialize_sparse_tensors`, using the types reported by
    `sparse.get_sparse_types`. The result must have the same nested
    structure as the input and equal sparse values.
    """
    test_cases = (
        (),
        sparse_tensor.SparseTensor(
            indices=[[0, 0]], values=[1], dense_shape=[1, 1]),
        sparse_tensor.SparseTensor(
            indices=[[3, 4]], values=[-1], dense_shape=[4, 5]),
        sparse_tensor.SparseTensor(
            indices=[[0, 0], [3, 4]], values=[1, -1], dense_shape=[4, 5]),
        # The trailing comma makes this a real one-element tuple. Without
        # it the parentheses are only grouping and the case would be the
        # same bare `SparseTensor` as the second entry.
        (sparse_tensor.SparseTensor(
            indices=[[0, 0]], values=[1], dense_shape=[1, 1]),),
        (sparse_tensor.SparseTensor(
            indices=[[0, 0]], values=[1], dense_shape=[1, 1]), ()),
        ((), sparse_tensor.SparseTensor(
            indices=[[0, 0]], values=[1], dense_shape=[1, 1])),
    )
    for expected in test_cases:
        actual = sparse.deserialize_sparse_tensors(
            sparse.serialize_sparse_tensors(expected),
            sparse.get_sparse_types(expected))
        nest.assert_same_structure(expected, actual)
        for a, e in zip(nest.flatten(actual), nest.flatten(expected)):
            self.assertSparseValuesEqual(a, e)
def tf_key_func(*args):
    """Defun wrapper that calls `key_func` on a repacked input element.

    The flat `args` receive the dense shapes of `input_dataset` via
    `set_shape`, are packed into the dataset's nested structure and passed
    through `sparse.deserialize_sparse_tensors` before `key_func` is
    invoked. The result is converted with `dtype=dtypes.int64`, and
    `dataset_ops._warn_if_collections` is called for
    "tf.contrib.data.group_by_window()".

    Returns:
      The result of `key_func`, converted to a tensor.

    Raises:
      ValueError: If the converted result's dtype is not `dtypes.int64`.
    """
    # Pass in shape information from the input_dataset.
    flat_dense_shapes = nest.flatten(
        sparse.as_dense_shapes(input_dataset.output_shapes,
                               input_dataset.output_classes))
    for flat_arg, dense_shape in zip(args, flat_dense_shapes):
        flat_arg.set_shape(dense_shape)

    element = sparse.deserialize_sparse_tensors(
        nest.pack_sequence_as(input_dataset.output_types, args),
        input_dataset.output_types, input_dataset.output_shapes,
        input_dataset.output_classes)

    unpack = dataset_ops._should_unpack_args(element)  # pylint: disable=protected-access
    raw_key = key_func(*element) if unpack else key_func(element)

    key = ops.convert_to_tensor(raw_key, dtype=dtypes.int64)
    if key.dtype != dtypes.int64:
        raise ValueError(
            "`key_func` must return a single tf.int64 tensor.")

    dataset_ops._warn_if_collections(
        "tf.contrib.data.group_by_window()")  # pylint: disable=protected-access
    return key
def get_next(self, name=None):
    """Returns a nested structure of `tf.Tensor`s containing the next element.

    The call counter `self._get_next_call_count` is incremented on every
    call; once it exceeds `GET_NEXT_CALL_WARNING_THRESHOLD`, a warning with
    `GET_NEXT_CALL_WARNING_MESSAGE` is emitted.

    Args:
      name: (Optional.) A name for the created operation.

    Returns:
      A nested structure of `tf.Tensor` objects.
    """
    self._get_next_call_count += 1
    if self._get_next_call_count > GET_NEXT_CALL_WARNING_THRESHOLD:
        warnings.warn(GET_NEXT_CALL_WARNING_MESSAGE)

    flat_types = nest.flatten(
        sparse.as_dense_types(self._output_types, self._output_classes))
    flat_shapes = nest.flatten(
        sparse.as_dense_shapes(self._output_shapes, self._output_classes))
    flat_tensors = gen_dataset_ops.iterator_get_next(
        self._iterator_resource,
        output_types=flat_types,
        output_shapes=flat_shapes,
        name=name)

    nested = nest.pack_sequence_as(self._output_types, flat_tensors)
    return sparse.deserialize_sparse_tensors(
        nested, self._output_types, self._output_shapes,
        self._output_classes)
def tf_finalize_func(*args):
    """Defun wrapper that calls `finalize_func` on the repacked state.

    The flat state `args` receive their dense shapes via `set_shape`, are
    packed as `self._state_types` and passed through
    `sparse.deserialize_sparse_tensors` before `finalize_func` is invoked.

    The classes, shapes and dtypes of the result are stored on `self`, and
    `dataset_ops._warn_if_collections` is called for
    "tf.contrib.data.group_by_reducer()".

    Returns:
      The flattened result after `sparse.serialize_sparse_tensors`.
    """
    state_dense_shapes = nest.flatten(
        sparse.as_dense_shapes(self._state_shapes, self._state_classes))
    for flat_arg, dense_shape in zip(args, state_dense_shapes):
        flat_arg.set_shape(dense_shape)

    state = sparse.deserialize_sparse_tensors(
        nest.pack_sequence_as(self._state_types, args),
        self._state_types, self._state_shapes, self._state_classes)

    result = finalize_func(state)

    def _to_tensor(value):
        # `SparseTensorValue`s become `SparseTensor`s; everything else is
        # converted to a tensor.
        if sparse_tensor.is_sparse(value):
            return sparse_tensor.SparseTensor.from_value(value)
        return ops.convert_to_tensor(value)

    result = nest.pack_sequence_as(
        result, [_to_tensor(v) for v in nest.flatten(result)])

    # Record the output signature derived from the finalized value.
    self._output_classes = sparse.get_classes(result)
    self._output_shapes = nest.pack_sequence_as(
        result, [c.get_shape() for c in nest.flatten(result)])
    self._output_types = nest.pack_sequence_as(
        result, [c.dtype for c in nest.flatten(result)])

    dataset_ops._warn_if_collections("tf.contrib.data.group_by_reducer()")  # pylint: disable=protected-access

    # Serialize any sparse tensors.
    serialized = sparse.serialize_sparse_tensors(result)
    result = nest.pack_sequence_as(result, list(nest.flatten(serialized)))
    return nest.flatten(result)
def testSerializeManyDeserialize(self, input_fn):
    """Round-trips the structure built by `input_fn`.

    The structure is passed through `sparse.serialize_many_sparse_tensors`
    and then `sparse.deserialize_sparse_tensors`. The result must have the
    same nested structure as the input and equal sparse values.

    Args:
      input_fn: Zero-argument callable that builds the structure under
        test.
    """
    original = input_fn()
    classes = sparse.get_classes(original)
    unknown_shapes = nest.map_structure(
        lambda _: tensor_shape.TensorShape(None), classes)
    int32_types = nest.map_structure(lambda _: dtypes.int32, classes)

    serialized = sparse.serialize_many_sparse_tensors(original)
    round_tripped = sparse.deserialize_sparse_tensors(
        serialized, int32_types, unknown_shapes,
        sparse.get_classes(original))

    nest.assert_same_structure(original, round_tripped)
    component_pairs = zip(nest.flatten(round_tripped),
                          nest.flatten(original))
    for got, want in component_pairs:
        self.assertSparseValuesEqual(got, want)
def _next_internal(self):
    """Returns a nested structure of `tf.Tensor`s containing the next element.

    Without a buffer resource handle the call is delegated to the parent
    class. Otherwise the flat element is read from the function buffering
    resource on `self._device`, packed, and passed through
    `sparse.deserialize_sparse_tensors`.
    """
    if self._buffer_resource_handle is None:
        return super(Iterator, self)._next_internal()

    with ops.device(self._device):
        flat_element = prefetching_ops.function_buffering_resource_get_next(
            function_buffer_resource=self._buffer_resource_handle,
            output_types=self._flat_output_types)

    packed = nest.pack_sequence_as(self._output_types, flat_element)
    return sparse.deserialize_sparse_tensors(
        packed, self._output_types, self._output_shapes,
        self._output_classes)
def _next_internal(self):
    """Returns a nested structure of `tf.Tensor`s containing the next element.

    When `self._buffer_resource_handle` is set, the flat element comes from
    the function buffering resource on `self._device` and is packed and
    passed through `sparse.deserialize_sparse_tensors`. When it is None,
    the parent class implementation is used.
    """
    has_buffer = self._buffer_resource_handle is not None
    if not has_buffer:
        return super(Iterator, self)._next_internal()

    with ops.device(self._device):
        flat_tensors = prefetching_ops.function_buffering_resource_get_next(
            function_buffer_resource=self._buffer_resource_handle,
            output_types=self._flat_output_types)

    nested = nest.pack_sequence_as(self._output_types, flat_tensors)
    return sparse.deserialize_sparse_tensors(
        nested, self._output_types, self._output_shapes,
        self._output_classes)
def _next_internal(self):
    """Returns a nested structure of `tf.Tensor`s containing the next element.

    The flat element is read from `self._buffering_resource` on
    `self._device` under `context.SYNC` execution mode, then packed and
    passed through `sparse.deserialize_sparse_tensors`.
    """
    # This runs in sync mode as iterators use an error status to communicate
    # that there is no more data to iterate over.
    # TODO(b/77291417): Fix
    with context.execution_mode(context.SYNC):
        with ops.device(self._device):
            flat_element = (
                ged_ops.experimental_function_buffering_resource_get_next(
                    function_buffer_resource=self._buffering_resource,
                    output_types=self._flat_output_types))

        packed = nest.pack_sequence_as(self._output_types, flat_element)
        return sparse.deserialize_sparse_tensors(
            packed, self._output_types, self._output_shapes,
            self._output_classes)
def _next_internal(self):
    """Returns a nested structure of `tf.Tensor`s containing the next element.

    Under `context.SYNC` execution mode, the flat element is fetched from
    `self._buffering_resource` on `self._device`, packed as
    `self._output_types`, and passed through
    `sparse.deserialize_sparse_tensors`.
    """
    # This runs in sync mode as iterators use an error status to communicate
    # that there is no more data to iterate over.
    # TODO(b/77291417): Fix
    with context.execution_mode(context.SYNC):
        with ops.device(self._device):
            flat_tensors = (
                gen_dataset_ops.function_buffering_resource_get_next(
                    function_buffer_resource=self._buffering_resource,
                    output_types=self._flat_output_types))

        nested = nest.pack_sequence_as(self._output_types, flat_tensors)
        return sparse.deserialize_sparse_tensors(
            nested, self._output_types, self._output_shapes,
            self._output_classes)
def get_single_element(dataset):
    """Returns the single element in `dataset` as a nested structure of tensors.

    With this function a @{tf.data.Dataset} can be used in a stateless
    "tensor-in tensor-out" expression, with no @{tf.data.Iterator} created.
    That is useful when preprocessing transformations are written as a
    `Dataset` and the same transformation is needed at serving time.
    For example:

    ```python
    input_batch = tf.placeholder(tf.string, shape=[BATCH_SIZE])

    def preprocessing_fn(input_str):
      # ...
      return image, label

    dataset = (tf.data.Dataset.from_tensor_slices(input_batch)
               .map(preprocessing_fn, num_parallel_calls=BATCH_SIZE)
               .batch(BATCH_SIZE))

    image_batch, label_batch = tf.contrib.data.get_single_element(dataset)
    ```

    Args:
      dataset: A @{tf.data.Dataset} object containing a single element.

    Returns:
      A nested structure of @{tf.Tensor} objects, corresponding to the single
      element of `dataset`.

    Raises:
      TypeError: if `dataset` is not a `tf.data.Dataset` object.
      InvalidArgumentError (at runtime): if `dataset` does not contain exactly
        one element.
    """
    if not isinstance(dataset, dataset_ops.Dataset):
        raise TypeError("`dataset` must be a `tf.data.Dataset` object.")

    variant = dataset._as_variant_tensor()  # pylint: disable=protected-access
    flat_types = nest.flatten(
        sparse.as_dense_types(dataset.output_types, dataset.output_classes))
    flat_shapes = nest.flatten(
        sparse.as_dense_shapes(dataset.output_shapes,
                               dataset.output_classes))
    flat_element = gen_dataset_ops.dataset_to_single_element(
        variant, output_types=flat_types, output_shapes=flat_shapes)

    nested_element = nest.pack_sequence_as(dataset.output_types, flat_element)
    return sparse.deserialize_sparse_tensors(
        nested_element, dataset.output_types, dataset.output_shapes,
        dataset.output_classes)
def get_value(self, name=None):
    """Reads the value held by `self._variant_tensor`.

    The flat result of `optional_get_value` is packed as
    `self._output_types` and passed through
    `sparse.deserialize_sparse_tensors`, all inside an "OptionalGetValue"
    name scope.

    Args:
      name: Optional name passed to `ops.name_scope`.

    Returns:
      The nested structure produced by `sparse.deserialize_sparse_tensors`.
    """
    # TODO(b/110122868): Consolidate the restructuring logic with similar
    # logic in `Iterator.get_next()` and `StructuredFunctionWrapper`.
    with ops.name_scope(name, "OptionalGetValue",
                        [self._variant_tensor]) as scope:
        flat_types = nest.flatten(
            sparse.as_dense_types(self._output_types, self._output_classes))
        flat_shapes = nest.flatten(
            sparse.as_dense_shapes(self._output_shapes,
                                   self._output_classes))
        flat_value = gen_dataset_ops.optional_get_value(
            self._variant_tensor,
            name=scope,
            output_types=flat_types,
            output_shapes=flat_shapes)

        packed = nest.pack_sequence_as(self._output_types, flat_value)
        return sparse.deserialize_sparse_tensors(
            packed, self._output_types, self._output_shapes,
            self._output_classes)
def get_next(self, name=None):
    """Returns a nested structure of `tf.Tensor`s containing the next element.

    The flat result of `iterator_get_next` is packed as
    `self._output_types` and passed through
    `sparse.deserialize_sparse_tensors`.

    Args:
      name: (Optional.) A name for the created operation.

    Returns:
      A nested structure of `tf.Tensor` objects.
    """
    flat_types = nest.flatten(sparse.unwrap_sparse_types(self._output_types))
    flat_shapes = nest.flatten(self._output_shapes)
    flat_element = gen_dataset_ops.iterator_get_next(
        self._iterator_resource,
        output_types=flat_types,
        output_shapes=flat_shapes,
        name=name)

    packed = nest.pack_sequence_as(self._output_types, flat_element)
    return sparse.deserialize_sparse_tensors(packed, self._output_types)
def get_value(self, name=None):
    """Extracts the value from `self._variant_tensor`.

    Inside an "OptionalGetValue" name scope, `optional_get_value` is called
    with the flattened dense types and shapes. Its flat result is packed as
    `self._output_types` and passed through
    `sparse.deserialize_sparse_tensors`.

    Args:
      name: Optional name passed to `ops.name_scope`.

    Returns:
      The nested structure produced by `sparse.deserialize_sparse_tensors`.
    """
    # TODO(b/110122868): Consolidate the restructuring logic with similar
    # logic in `Iterator.get_next()` and `StructuredFunctionWrapper`.
    scope_values = [self._variant_tensor]
    with ops.name_scope(name, "OptionalGetValue", scope_values) as scope:
        dense_types = nest.flatten(
            sparse.as_dense_types(self._output_types, self._output_classes))
        dense_shapes = nest.flatten(
            sparse.as_dense_shapes(self._output_shapes,
                                   self._output_classes))
        flat_tensors = gen_dataset_ops.optional_get_value(
            self._variant_tensor,
            name=scope,
            output_types=dense_types,
            output_shapes=dense_shapes)

        nested = nest.pack_sequence_as(self._output_types, flat_tensors)
        return sparse.deserialize_sparse_tensors(
            nested, self._output_types, self._output_shapes,
            self._output_classes)
def tf_key_func(*args):
    """Defun wrapper that calls `key_func` on a repacked input element.

    The flat `args` receive the shapes of `input_dataset` via `set_shape`,
    are packed into the dataset's nested structure and passed through
    `sparse.deserialize_sparse_tensors` before `key_func` is invoked. The
    result is converted with `dtype=dtypes.int64`.

    Returns:
      The result of `key_func`, converted to a tensor.

    Raises:
      ValueError: If the converted result's dtype is not `dtypes.int64`.
    """
    # Pass in shape information from the input_dataset.
    flat_shapes = nest.flatten(input_dataset.output_shapes)
    for flat_arg, known_shape in zip(args, flat_shapes):
        flat_arg.set_shape(known_shape)

    element = sparse.deserialize_sparse_tensors(
        nest.pack_sequence_as(input_dataset.output_types, args),
        input_dataset.output_types)

    unpack = dataset_ops._should_unpack_args(element)  # pylint: disable=protected-access
    raw_key = key_func(*element) if unpack else key_func(element)

    key = ops.convert_to_tensor(raw_key, dtype=dtypes.int64)
    if key.dtype != dtypes.int64:
        raise ValueError("`key_func` must return a single tf.int64 tensor.")
    return key
def get_next(self, name=None):
    """Returns a nested structure of `tf.Tensor`s containing the next element.

    Calls `iterator_get_next` with the flattened unwrapped types and
    shapes, packs the flat result as `self._output_types`, and passes it
    through `sparse.deserialize_sparse_tensors`.

    Args:
      name: (Optional.) A name for the created operation.

    Returns:
      A nested structure of `tf.Tensor` objects.
    """
    unwrapped_types = sparse.unwrap_sparse_types(self._output_types)
    flat_tensors = gen_dataset_ops.iterator_get_next(
        self._iterator_resource,
        output_types=nest.flatten(unwrapped_types),
        output_shapes=nest.flatten(self._output_shapes),
        name=name)

    nested = nest.pack_sequence_as(self._output_types, flat_tensors)
    return sparse.deserialize_sparse_tensors(nested, self._output_types)
def _next_internal(self):
    """Returns a nested structure of `tf.Tensor`s containing the next element.

    Everything runs under `context.SYNC` execution mode. Without a buffer
    resource handle the call is delegated to the parent class. Otherwise
    the flat element is read from the function buffering resource on
    `self._device`, packed, and passed through
    `sparse.deserialize_sparse_tensors`.
    """
    # This runs in sync mode as iterators use an error status to communicate
    # that there is no more data to iterate over.
    # TODO(b/77291417): Fix
    with context.execution_mode(context.SYNC):
        if self._buffer_resource_handle is None:
            return super(Iterator, self)._next_internal()

        with ops.device(self._device):
            flat_element = (
                prefetching_ops.function_buffering_resource_get_next(
                    function_buffer_resource=self._buffer_resource_handle,
                    output_types=self._flat_output_types))

        packed = nest.pack_sequence_as(self._output_types, flat_element)
        return sparse.deserialize_sparse_tensors(
            packed, self._output_types, self._output_shapes,
            self._output_classes)
def _next_internal(self):
    """Returns a nested structure of `tf.Tensor`s containing the next element.

    The flat element is read on `self._device`: from the function buffering
    resource when `self._buffer_resource_handle` is set, otherwise from
    `self._resource` with `iterator_get_next`. The flat result is then
    packed and passed through `sparse.deserialize_sparse_tensors`.
    """
    with ops.device(self._device):
        if self._buffer_resource_handle is None:
            # TODO(ashankar): Consider removing this ops.device()
            # contextmanager and instead mimic ops placement in graphs:
            # Operations on resource handles execute on the same device as
            # where the resource is placed.
            flat_element = gen_dataset_ops.iterator_get_next(
                self._resource,
                output_types=self._flat_output_types,
                output_shapes=self._flat_output_shapes)
        else:
            flat_element = (
                prefetching_ops.function_buffering_resource_get_next(
                    function_buffer_resource=self._buffer_resource_handle,
                    output_types=self._flat_output_types))

    packed = nest.pack_sequence_as(self._output_types, flat_element)
    return sparse.deserialize_sparse_tensors(
        packed, self._output_types, self._output_shapes,
        self._output_classes)
def _next_internal(self):
    """Returns a nested structure of `tf.Tensor`s containing the next element.

    Reads the flat element on `self._device`, preferring the function
    buffering resource when `self._buffer_resource_handle` is not None and
    falling back to `iterator_get_next` on `self._resource`. The flat
    tensors are then packed as `self._output_types` and passed through
    `sparse.deserialize_sparse_tensors`.
    """
    use_buffer = self._buffer_resource_handle is not None
    with ops.device(self._device):
        if use_buffer:
            flat_tensors = (
                prefetching_ops.function_buffering_resource_get_next(
                    function_buffer_resource=self._buffer_resource_handle,
                    output_types=self._flat_output_types))
        else:
            # TODO(ashankar): Consider removing this ops.device()
            # contextmanager and instead mimic ops placement in graphs:
            # Operations on resource handles execute on the same device as
            # where the resource is placed.
            flat_tensors = gen_dataset_ops.iterator_get_next(
                self._resource,
                output_types=self._flat_output_types,
                output_shapes=self._flat_output_shapes)

    nested = nest.pack_sequence_as(self._output_types, flat_tensors)
    return sparse.deserialize_sparse_tensors(
        nested, self._output_types, self._output_shapes,
        self._output_classes)
def get_next(self, name=None):
    """See @{tf.data.Iterator.get_next}.

    Fetches one element from `self._buffering_resource`, packs it as
    `self.output_types`, and passes it through
    `sparse.deserialize_sparse_tensors`. The call counter is incremented
    first; a warning is emitted once it exceeds
    `iterator_ops.GET_NEXT_CALL_WARNING_THRESHOLD`.
    """
    self._get_next_call_count += 1
    if (self._get_next_call_count >
            iterator_ops.GET_NEXT_CALL_WARNING_THRESHOLD):
        warnings.warn(iterator_ops.GET_NEXT_CALL_WARNING_MESSAGE)

    dense_types = nest.flatten(
        sparse.as_dense_types(self.output_types, self.output_classes))
    flat_element = gen_dataset_ops.function_buffering_resource_get_next(
        self._buffering_resource, output_types=dense_types, name=name)

    element = sparse.deserialize_sparse_tensors(
        nest.pack_sequence_as(self.output_types, flat_element),
        self.output_types, self.output_shapes, self.output_classes)

    # Only dense tensors get the known output shape applied.
    component_shapes = zip(nest.flatten(element),
                           nest.flatten(self.output_shapes))
    for component, known_shape in component_shapes:
        if isinstance(component, ops.Tensor):
            component.set_shape(known_shape)

    return element
def tf_finalize_func(*args):
    """A wrapper for Defun that facilitates shape inference.

    Args:
      *args: The flat state components.

    Returns:
      The flat list of output components, with sparse tensors serialized.

    Side effects:
      Records the classes, shapes and types of the `finalize_func` result in
      `self._output_classes`, `self._output_shapes` and `self._output_types`.
    """
    # Pass in shape information from the state.
    dense_state_shapes = nest.flatten(
        sparse.as_dense_shapes(self._state_shapes, self._state_classes))
    for state_arg, state_shape in zip(args, dense_state_shapes):
        state_arg.set_shape(state_shape)

    nested_args = sparse.deserialize_sparse_tensors(
        nest.pack_sequence_as(self._state_types, args),
        self._state_types, self._state_shapes, self._state_classes)

    ret = finalize_func(nested_args)

    # Convert any `SparseTensorValue`s to `SparseTensor`s and all other
    # values to tensors.
    converted = []
    for value in nest.flatten(ret):
        if sparse_tensor.is_sparse(value):
            converted.append(sparse_tensor.SparseTensor.from_value(value))
        else:
            converted.append(ops.convert_to_tensor(value))
    ret = nest.pack_sequence_as(ret, converted)

    # Capture the output signature from the converted return value.
    flat_ret = nest.flatten(ret)
    self._output_classes = sparse.get_classes(ret)
    self._output_shapes = nest.pack_sequence_as(
        ret, [component.get_shape() for component in flat_ret])
    self._output_types = nest.pack_sequence_as(
        ret, [component.dtype for component in flat_ret])

    dataset_ops._warn_if_collections(  # pylint: disable=protected-access
        "tf.contrib.data.group_by_reducer()")

    # Serialize any sparse tensors.
    serialized = [
        component
        for component in nest.flatten(sparse.serialize_sparse_tensors(ret))
    ]
    ret = nest.pack_sequence_as(ret, serialized)
    return nest.flatten(ret)
def testSerializeDeserialize(self):
    """Checks that serialize followed by deserialize round-trips each case."""

    def _sparse(indices, values, dense_shape):
        # Builds a fresh SparseTensor for each test case.
        return sparse_tensor.SparseTensor(
            indices=indices, values=values, dense_shape=dense_shape)

    test_cases = (
        (),
        _sparse([[0, 0]], [1], [1, 1]),
        _sparse([[3, 4]], [-1], [4, 5]),
        _sparse([[0, 0], [3, 4]], [1, -1], [4, 5]),
        # NOTE(review): the original wrapped this case in parentheses without a
        # trailing comma, so it is a bare SparseTensor rather than a 1-tuple.
        # Preserved as-is -- confirm whether a 1-tuple was intended.
        _sparse([[0, 0]], [1], [1, 1]),
        (_sparse([[0, 0]], [1], [1, 1]), ()),
        ((), _sparse([[0, 0]], [1], [1, 1])),
    )

    for expected in test_cases:
        serialized = sparse.serialize_sparse_tensors(expected)
        sparse_types = sparse.get_sparse_types(expected)
        actual = sparse.deserialize_sparse_tensors(serialized, sparse_types)

        nest.assert_same_structure(expected, actual)
        flat_actual = nest.flatten(actual)
        flat_expected = nest.flatten(expected)
        for got, want in zip(flat_actual, flat_expected):
            self.assertSparseValuesEqual(got, want)
def _next_internal(self):
    """Returns a nested structure of `tf.Tensor`s containing the next element.

    The element is fetched with `iterator_get_next_sync` under synchronous
    execution mode, repacked to mirror `self._output_types`, and passed through
    `sparse.deserialize_sparse_tensors`.
    """
    # This runs in sync mode as iterators use an error status to communicate
    # that there is no more data to iterate over.
    # TODO(b/77291417): Fix
    with context.execution_mode(context.SYNC):
        with ops.device(self._device):
            # TODO(ashankar): Consider removing this ops.device() contextmanager
            # and instead mimic ops placement in graphs: Operations on resource
            # handles execute on the same device as where the resource is placed.
            # NOTE(mrry): Here we use the "_sync" variant of `iterator_get_next`
            # because in eager mode this code will run synchronously on the
            # calling thread. Therefore we do not need to make a defensive
            # context switch to a background thread, and can achieve a small
            # constant performance boost by invoking the iterator synchronously.
            flat_element = gen_dataset_ops.iterator_get_next_sync(
                self._resource,
                output_types=self._flat_output_types,
                output_shapes=self._flat_output_shapes)

        # NOTE(review): the flattened source does not show which scope this
        # return belonged to; it is kept inside the sync-mode scope but outside
        # the device scope here -- confirm.
        nested_element = nest.pack_sequence_as(self._output_types, flat_element)
        return sparse.deserialize_sparse_tensors(
            nested_element, self._output_types, self._output_shapes,
            self._output_classes)
def tf_scan_func(*args):
    """A wrapper for Defun that facilitates shape inference.

    Args:
      *args: Flat arguments: the state components first, followed by the
        components of one input element.

    Returns:
      A flat list holding the new-state components followed by the output
      components, with any sparse tensors serialized.

    Raises:
      TypeError: If `scan_func` does not return a two-element sequence, or if
        the classes or dtypes of the new state differ from the initial state.

    Side effects:
      Sets `self._output_classes`, `self._output_shapes` and
      `self._output_types` from the output value, and extends
      `flat_new_state_shapes` with the shapes of the new state.
    """
    # Pass in shape information from the state and input_dataset.
    dense_state_shapes = nest.flatten(
        sparse.as_dense_shapes(self._state_shapes, self._state_classes))
    dense_input_shapes = nest.flatten(
        sparse.as_dense_shapes(input_dataset.output_shapes,
                               input_dataset.output_classes))
    for arg, shape in zip(args, dense_state_shapes + dense_input_shapes):
        arg.set_shape(shape)

    # The first `pivot` arguments are state; the remainder is the input element.
    pivot = len(nest.flatten(self._state_shapes))
    nested_state_args = nest.pack_sequence_as(self._state_types, args[:pivot])
    nested_state_args = sparse.deserialize_sparse_tensors(
        nested_state_args, self._state_types, self._state_shapes,
        self._state_classes)
    nested_input_args = nest.pack_sequence_as(input_dataset.output_types,
                                              args[pivot:])
    nested_input_args = sparse.deserialize_sparse_tensors(
        nested_input_args, input_dataset.output_types,
        input_dataset.output_shapes, input_dataset.output_classes)

    ret = scan_func(nested_state_args, nested_input_args)
    # NOTE(review): `collections.Sequence` is the legacy alias that Python 3.10
    # removed; switch to `collections.abc.Sequence` once Python 2 compatibility
    # is no longer required.
    if not isinstance(ret, collections.Sequence) or len(ret) != 2:
        raise TypeError("The scan function must return a pair comprising the "
                        "new state and the output value.")

    # Convert any `SparseTensorValue`s to `SparseTensor`s and all other
    # values to tensors.
    ret = nest.pack_sequence_as(ret, [
        sparse_tensor.SparseTensor.from_value(t)
        if sparse_tensor.is_sparse(t) else ops.convert_to_tensor(t)
        for t in nest.flatten(ret)
    ])
    new_state, output_value = ret
    flat_new_state = nest.flatten(new_state)
    flat_output_value = nest.flatten(output_value)

    # Extract and validate class information from the returned values.
    for t, clazz in zip(flat_new_state, nest.flatten(self._state_classes)):
        if not isinstance(t, clazz):
            raise TypeError(
                "The element classes for the new state must match the initial "
                "state. Expected %s; got %s." %
                (self._state_classes,
                 nest.pack_sequence_as(self._state_types,
                                       [type(t) for t in flat_new_state])))
    self._output_classes = sparse.get_classes(output_value)

    # Extract shape information from the returned values.
    flat_new_state_shapes.extend([t.get_shape() for t in flat_new_state])
    self._output_shapes = nest.pack_sequence_as(
        output_value, [t.get_shape() for t in flat_output_value])

    # Extract and validate type information from the returned values.
    for t, dtype in zip(flat_new_state, nest.flatten(self._state_types)):
        if t.dtype != dtype:
            raise TypeError(
                "The element types for the new state must match the initial "
                "state. Expected %s; got %s." %
                (self._state_types,
                 nest.pack_sequence_as(self._state_types,
                                       [t.dtype for t in flat_new_state])))
    self._output_types = nest.pack_sequence_as(
        output_value, [t.dtype for t in flat_output_value])

    dataset_ops._warn_if_collections("tf.contrib.data.scan()")  # pylint: disable=protected-access

    # Serialize any sparse tensors.
    new_state = nest.pack_sequence_as(
        new_state, nest.flatten(sparse.serialize_sparse_tensors(new_state)))
    output_value = nest.pack_sequence_as(
        output_value,
        nest.flatten(sparse.serialize_sparse_tensors(output_value)))
    return nest.flatten(new_state) + nest.flatten(output_value)
def tf_scan_func(*args):
    """A wrapper for Defun that facilitates shape inference.

    Args:
      *args: Flat arguments; the leading entries are the state components and
        the trailing entries are the components of one input element.

    Returns:
      The flattened new state concatenated with the flattened output value,
      with any sparse tensors serialized.

    Raises:
      TypeError: If `scan_func` does not return a sequence of length two, or if
        the new state's classes or dtypes do not match the initial state.

    Side effects:
      Assigns `self._output_classes`, `self._output_shapes` and
      `self._output_types` based on the output value, and appends the new
      state's shapes to `flat_new_state_shapes`.
    """
    # Pass in shape information from the state and input_dataset.
    flat_arg_shapes = (
        nest.flatten(
            sparse.as_dense_shapes(self._state_shapes, self._state_classes)) +
        nest.flatten(
            sparse.as_dense_shapes(input_dataset.output_shapes,
                                   input_dataset.output_classes)))
    for arg, shape in zip(args, flat_arg_shapes):
        arg.set_shape(shape)

    # Split the flat arguments into the state part and the input part.
    pivot = len(nest.flatten(self._state_shapes))
    nested_state_args = sparse.deserialize_sparse_tensors(
        nest.pack_sequence_as(self._state_types, args[:pivot]),
        self._state_types, self._state_shapes, self._state_classes)
    nested_input_args = sparse.deserialize_sparse_tensors(
        nest.pack_sequence_as(input_dataset.output_types, args[pivot:]),
        input_dataset.output_types, input_dataset.output_shapes,
        input_dataset.output_classes)

    ret = scan_func(nested_state_args, nested_input_args)
    # NOTE(review): `collections.Sequence` is the legacy alias that Python 3.10
    # removed; switch to `collections.abc.Sequence` once Python 2 compatibility
    # is no longer required.
    if not isinstance(ret, collections.Sequence) or len(ret) != 2:
        raise TypeError(
            "The scan function must return a pair comprising the "
            "new state and the output value.")

    # Convert any `SparseTensorValue`s to `SparseTensor`s and all other
    # values to tensors.
    ret = nest.pack_sequence_as(ret, [
        sparse_tensor.SparseTensor.from_value(t)
        if sparse_tensor.is_sparse(t) else ops.convert_to_tensor(t)
        for t in nest.flatten(ret)
    ])
    new_state, output_value = ret
    state_components = nest.flatten(new_state)
    output_components = nest.flatten(output_value)

    # Extract and validate class information from the returned values.
    for t, clazz in zip(state_components, nest.flatten(self._state_classes)):
        if not isinstance(t, clazz):
            raise TypeError(
                "The element classes for the new state must match the initial "
                "state. Expected %s; got %s." %
                (self._state_classes,
                 nest.pack_sequence_as(self._state_types,
                                       [type(t) for t in state_components])))
    self._output_classes = sparse.get_classes(output_value)

    # Extract shape information from the returned values.
    flat_new_state_shapes.extend([t.get_shape() for t in state_components])
    self._output_shapes = nest.pack_sequence_as(
        output_value, [t.get_shape() for t in output_components])

    # Extract and validate type information from the returned values.
    for t, dtype in zip(state_components, nest.flatten(self._state_types)):
        if t.dtype != dtype:
            raise TypeError(
                "The element types for the new state must match the initial "
                "state. Expected %s; got %s." %
                (self._state_types,
                 nest.pack_sequence_as(self._state_types,
                                       [t.dtype for t in state_components])))
    self._output_types = nest.pack_sequence_as(
        output_value, [t.dtype for t in output_components])

    dataset_ops._warn_if_collections("tf.contrib.data.scan()")  # pylint: disable=protected-access

    # Serialize any sparse tensors.
    new_state = nest.pack_sequence_as(
        new_state, nest.flatten(sparse.serialize_sparse_tensors(new_state)))
    output_value = nest.pack_sequence_as(
        output_value,
        nest.flatten(sparse.serialize_sparse_tensors(output_value)))
    return nest.flatten(new_state) + nest.flatten(output_value)
def get_next(self, name=None):
    """Returns a nested structure of `tf.Tensor`s representing the next element.

    In graph mode this method is normally called *once*, and its result is fed
    into further computation. A training loop then repeatedly calls
    @{tf.Session.run} on that computation and stops when the
    `Iterator.get_next()` operation raises @{tf.errors.OutOfRangeError}.

    Skeleton of a training loop built this way:

    ```python
    dataset = ...  # A `tf.data.Dataset` object.
    iterator = dataset.make_initializable_iterator()
    next_element = iterator.get_next()

    # Build a TensorFlow graph that does something with each element.
    loss = model_function(next_element)

    optimizer = ...  # A `tf.train.Optimizer` object.
    train_op = optimizer.minimize(loss)

    with tf.Session() as sess:
      try:
        while True:
          sess.run(train_op)
      except tf.errors.OutOfRangeError:
        pass
    ```

    NOTE: Calling `Iterator.get_next()` several times is legitimate, e.g. to
    distribute different elements to multiple devices within one step. The
    common pitfall is calling it in every iteration of the training loop:
    each call adds ops to the graph, and executing each op allocates resources
    (including threads), which leads to slowdown and eventual resource
    exhaustion. To guard against this, a warning is issued once the number of
    calls exceeds a fixed threshold.

    Args:
      name: (Optional.) A name for the created operation.

    Returns:
      A nested structure of `tf.Tensor` objects.
    """
    self._get_next_call_count += 1
    if self._get_next_call_count > GET_NEXT_CALL_WARNING_THRESHOLD:
        warnings.warn(GET_NEXT_CALL_WARNING_MESSAGE)

    # The op works on flat, dense-typed components; sparse components are
    # restored by the deserialization step below.
    flat_dense_types = nest.flatten(
        sparse.as_dense_types(self._output_types, self._output_classes))
    flat_dense_shapes = nest.flatten(
        sparse.as_dense_shapes(self._output_shapes, self._output_classes))
    flat_element = gen_dataset_ops.iterator_get_next(
        self._iterator_resource,
        output_types=flat_dense_types,
        output_shapes=flat_dense_shapes,
        name=name)

    nested_element = nest.pack_sequence_as(self._output_types, flat_element)
    return sparse.deserialize_sparse_tensors(
        nested_element, self._output_types, self._output_shapes,
        self._output_classes)