def tf_key_func(*args):
  """A Defun wrapper that restores shape information and applies `key_func`."""
  # Static shapes are lost crossing the Defun boundary; reattach them from
  # the input dataset's (dense) shape structure.
  flat_dense_shapes = nest.flatten(
      sparse.as_dense_shapes(input_dataset.output_shapes,
                             input_dataset.output_classes))
  for tensor, dense_shape in zip(args, flat_dense_shapes):
    tensor.set_shape(dense_shape)
  packed = nest.pack_sequence_as(input_dataset.output_types, args)
  packed = sparse.deserialize_sparse_tensors(
      packed, input_dataset.output_types, input_dataset.output_shapes,
      input_dataset.output_classes)
  # pylint: disable=protected-access
  should_unpack = dataset_ops._should_unpack_args(packed)
  # pylint: enable=protected-access
  result = key_func(*packed) if should_unpack else key_func(packed)
  result = ops.convert_to_tensor(result)
  # The grouping key must be a scalar int64 tensor.
  if result.dtype != dtypes.int64 or result.get_shape() != tensor_shape.scalar():
    raise ValueError(
        "`key_func` must return a single tf.int64 tensor. "
        "Got type=%s and shape=%s" % (result.dtype, result.get_shape()))
  dataset_ops._warn_if_collections("tf.contrib.data.group_by_reducer()")  # pylint: disable=protected-access
  return result
def tf_map_func(*args):
  """A Defun wrapper that restores shape information and applies `map_func`."""
  # Static shapes are lost crossing the Defun boundary; reattach them from
  # the input dataset's (dense) shape structure.
  flat_dense_shapes = nest.flatten(
      sparse.as_dense_shapes(input_dataset.output_shapes,
                             input_dataset.output_classes))
  for tensor, dense_shape in zip(args, flat_dense_shapes):
    tensor.set_shape(dense_shape)
  packed = nest.pack_sequence_as(input_dataset.output_types, args)
  packed = sparse.deserialize_sparse_tensors(
      packed, input_dataset.output_types, input_dataset.output_shapes,
      input_dataset.output_classes)
  if dataset_ops._should_unpack_args(packed):  # pylint: disable=protected-access
    dataset = map_func(*packed)
  else:
    dataset = map_func(packed)
  if not isinstance(dataset, dataset_ops.Dataset):
    raise TypeError("`map_func` must return a `Dataset` object.")
  # Record the structure of the dataset produced by `map_func`.
  self._output_classes = dataset.output_classes
  self._output_types = dataset.output_types
  self._output_shapes = dataset.output_shapes
  return dataset._as_variant_tensor()  # pylint: disable=protected-access
def tf_key_func(*args):
  """A Defun wrapper that restores shape information and applies `key_func`."""
  # Static shapes are lost crossing the Defun boundary; reattach them from
  # the input dataset's (dense) shape structure.
  flat_dense_shapes = nest.flatten(
      sparse.as_dense_shapes(input_dataset.output_shapes,
                             input_dataset.output_classes))
  for tensor, dense_shape in zip(args, flat_dense_shapes):
    tensor.set_shape(dense_shape)
  packed = nest.pack_sequence_as(input_dataset.output_types, args)
  packed = sparse.deserialize_sparse_tensors(
      packed, input_dataset.output_types, input_dataset.output_shapes,
      input_dataset.output_classes)
  # pylint: disable=protected-access
  should_unpack = dataset_ops._should_unpack_args(packed)
  # pylint: enable=protected-access
  result = key_func(*packed) if should_unpack else key_func(packed)
  # Coerce to int64 up front; the explicit check below guards the contract.
  result = ops.convert_to_tensor(result, dtype=dtypes.int64)
  if result.dtype != dtypes.int64:
    raise ValueError(
        "`key_func` must return a single tf.int64 tensor.")
  dataset_ops._warn_if_collections(
      "tf.contrib.data.group_by_window()")  # pylint: disable=protected-access
  return result
def tf_key_func(*args):
  """A Defun wrapper that restores shape information and applies `key_func`."""
  # Static shapes are lost crossing the Defun boundary; reattach them from
  # the input dataset's (dense) shape structure.
  flat_dense_shapes = nest.flatten(
      sparse.as_dense_shapes(input_dataset.output_shapes,
                             input_dataset.output_classes))
  for tensor, dense_shape in zip(args, flat_dense_shapes):
    tensor.set_shape(dense_shape)
  packed = nest.pack_sequence_as(input_dataset.output_types, args)
  packed = sparse.deserialize_sparse_tensors(
      packed, input_dataset.output_types, input_dataset.output_shapes,
      input_dataset.output_classes)
  # pylint: disable=protected-access
  should_unpack = dataset_ops._should_unpack_args(packed)
  # pylint: enable=protected-access
  result = key_func(*packed) if should_unpack else key_func(packed)
  result = ops.convert_to_tensor(result)
  # The grouping key must be a scalar int64 tensor.
  if result.dtype != dtypes.int64 or result.get_shape() != tensor_shape.scalar():
    raise ValueError(
        "`key_func` must return a single tf.int64 tensor. "
        "Got type=%s and shape=%s" % (result.dtype, result.get_shape()))
  return result
def __init__(self, input_dataset, row_splits_dtype):
  """Constructs a new _DenseToRaggedDataset.

  Args:
    input_dataset: The dataset whose tf.Tensor elements should be made
      ragged.
    row_splits_dtype: The dtype that should be used for the `row_splits` of
      any new ragged tensors.  Existing `tf.RaggedTensor` elements do *not*
      have their row_splits dtype changed.
  """

  # Swap each TensorSpec with a partially-known shape for an equivalent
  # RaggedTensorSpec; all other specs pass through unchanged.
  def to_ragged_spec(spec):
    """Returns the new spec based on RaggedTensors."""
    if (not isinstance(spec, tensor_spec.TensorSpec) or
        spec.shape.rank is None or
        spec.shape.is_fully_defined()):
      return spec
    # Ragged rank is the index of the last unknown axis; non-empty because
    # the shape is not fully defined.
    ragged_rank = max(axis for axis, dim in enumerate(spec.shape.as_list())
                      if dim is None)
    return ragged_tensor.RaggedTensorSpec(
        shape=spec.shape,
        dtype=spec.dtype,
        ragged_rank=ragged_rank,
        row_splits_dtype=row_splits_dtype)

  self._structure = nest.map_structure(to_ragged_spec,
                                       input_dataset.element_spec)

  # Re-encode each qualifying tf.Tensor value as a variant-encoded
  # RaggedTensor.  Since the structure above now carries RaggedTensorSpecs,
  # these variants will be decoded via RaggedTensorSpec._from_tensor_list.
  def to_ragged_variant(value):
    """Re-encode Tensors as RaggedTensors."""
    if (not isinstance(value, ops.Tensor) or
        value.shape.rank is None or
        value.shape.is_fully_defined()):
      return value
    spec = to_ragged_spec(tensor_spec.TensorSpec.from_tensor(value))
    if spec._ragged_rank > 0:  # pylint: disable=protected-access
      value = ragged_tensor.RaggedTensor.from_tensor(
          value, ragged_rank=spec._ragged_rank)  # pylint: disable=protected-access
    return spec._to_tensor_list(value)[0]  # pylint: disable=protected-access

  # Tuples are automatically unpacked by `dataset.map` so we repack them.
  if dataset_ops._should_unpack_args(input_dataset.element_spec):  # pylint: disable=protected-access
    map_fn = lambda *value: nest.map_structure(to_ragged_variant, value)
  else:
    map_fn = lambda value: nest.map_structure(to_ragged_variant, value)

  self._mapped_dataset = input_dataset.map(map_fn)

  variant = self._mapped_dataset._variant_tensor  # pylint: disable=protected-access
  super(_DenseToRaggedDataset, self).__init__(input_dataset, variant)
def __init__(self, input_dataset, row_splits_dtype):
  """Constructs a new _DenseToRaggedDataset.

  Args:
    input_dataset: The dataset whose tf.Tensor elements should be made
      ragged.
    row_splits_dtype: The dtype that should be used for the `row_splits` of
      any new ragged tensors.  Existing `tf.RaggedTensor` elements do *not*
      have their row_splits dtype changed.
  """

  # Swap each non-scalar TensorSpec for a rank-0-ragged RaggedTensorSpec
  # carrying the requested row_splits dtype; all other specs pass through.
  def to_ragged_spec(spec):
    """Returns the ragged-aware replacement for `spec`."""
    if isinstance(spec, tensor_spec.TensorSpec) and spec.shape.ndims != 0:
      return ragged_tensor.RaggedTensorSpec(
          shape=spec.shape,
          dtype=spec.dtype,
          ragged_rank=0,
          row_splits_dtype=row_splits_dtype)
    return spec

  self._structure = nest.map_structure(to_ragged_spec,
                                       input_dataset.element_spec)

  # Re-encode each non-scalar tf.Tensor value as a variant-encoded
  # RaggedTensor.  Since the structure above now carries RaggedTensorSpecs,
  # these variants will be decoded via RaggedTensorSpec._from_tensor_list.
  def to_ragged_variant(value):
    """Variant-encodes `value` when it is a non-scalar dense Tensor."""
    if not isinstance(value, ops.Tensor) or value.shape.ndims == 0:
      return value
    spec = to_ragged_spec(tensor_spec.TensorSpec.from_tensor(value))
    return spec._to_tensor_list(value)[0]  # pylint: disable=protected-access

  # Tuples are automatically unpacked by `dataset.map` so we repack them.
  if dataset_ops._should_unpack_args(input_dataset.element_spec):  # pylint: disable=protected-access
    map_fn = lambda *value: nest.map_structure(to_ragged_variant, value)
  else:
    map_fn = lambda value: nest.map_structure(to_ragged_variant, value)

  self._mapped_dataset = input_dataset.map(map_fn)

  variant = self._mapped_dataset._variant_tensor  # pylint: disable=protected-access
  super(_DenseToRaggedDataset, self).__init__(input_dataset, variant)
def tf_key_func(*args):
  """A Defun wrapper that restores shape information and applies `key_func`."""
  # Static shapes are lost crossing the Defun boundary; reattach them.
  for tensor, shape in zip(args, nest.flatten(input_dataset.output_shapes)):
    tensor.set_shape(shape)
  packed = nest.pack_sequence_as(input_dataset.output_types, args)
  # pylint: disable=protected-access
  should_unpack = dataset_ops._should_unpack_args(packed)
  # pylint: enable=protected-access
  result = key_func(*packed) if should_unpack else key_func(packed)
  # Coerce to int64 up front; the explicit check below guards the contract.
  result = ops.convert_to_tensor(result, dtype=dtypes.int64)
  if result.dtype != dtypes.int64:
    raise ValueError("`key_func` must return a single tf.int64 tensor.")
  return result