def testConvertLegacyStructureFail(self):
    """Check that an unsupported output class is rejected with `TypeError`.

    `_EagerTensorArray` is supplied as the legacy output class and the
    conversion must fail with the message matched below.
    """
    expected_message = (
        "Could not build a structure for output class "
        "_EagerTensorArray. Make sure any component class in "
        "`output_classes` inherits from one of the following classes: "
        "`tf.TypeSpec`, `tf.sparse.SparseTensor`, `tf.Tensor`, "
        "`tf.TensorArray`.")
    legacy_type = dtypes.int32
    legacy_shape = tensor_shape.TensorShape([2, None])
    # pylint: disable=protected-access
    legacy_class = tensor_array_ops._EagerTensorArray
    with self.assertRaisesRegex(TypeError, expected_message):
        structure.convert_legacy_structure(
            legacy_type, legacy_shape, legacy_class)
def __init__(self, input_dataset):
    """See `unbatch()` for more details.

    Args:
      input_dataset: The dataset whose elements are split along their
        leading dimension.

    Raises:
      ValueError: If any component is a scalar, or if the leading dimensions
        of the components cannot be merged with each other.
    """
    legacy_shapes = dataset_ops.get_legacy_output_shapes(input_dataset)
    component_shapes = nest.flatten(legacy_shapes)
    for component_shape in component_shapes:
        if component_shape.ndims == 0:
            raise ValueError(
                "Cannot unbatch an input with scalar components.")

    # Merging every leading dimension into one running dimension fails as
    # soon as two components disagree on a known batch size.
    merged_batch_dim = tensor_shape.Dimension(None)
    try:
        for component_shape in component_shapes:
            merged_batch_dim = merged_batch_dim.merge_with(
                component_shape[0])
    except ValueError:
        raise ValueError(
            "Cannot unbatch an input whose components have "
            "different batch sizes.")

    self._input_dataset = input_dataset
    element_types = dataset_ops.get_legacy_output_types(input_dataset)
    # Drop the leading (batch) dimension from every component shape.
    element_shapes = nest.map_structure(
        lambda shape: shape[1:], legacy_shapes)
    element_classes = dataset_ops.get_legacy_output_classes(input_dataset)
    self._structure = structure.convert_legacy_structure(
        element_types, element_shapes, element_classes)

    # pylint: disable=protected-access
    input_variant = self._input_dataset._variant_tensor
    variant_tensor = ged_ops.experimental_unbatch_dataset(
        input_variant, **dataset_ops.flat_structure(self))
    super(_UnbatchDataset, self).__init__(input_dataset, variant_tensor)
def __init__(self, iterator_resource, initializer, output_types,
             output_shapes, output_classes):
    """Wraps an existing iterator resource in an iterator object.

    Note: This initializer is rarely called directly; iterators are normally
    obtained from `Dataset.make_initializable_iterator()` or
    `Dataset.make_one_shot_iterator()`.

    Args:
      iterator_resource: A `tf.resource` scalar `tf.Tensor` representing the
        iterator.
      initializer: A `tf.Operation` that should be run to initialize this
        iterator.
      output_types: A nested structure of `tf.DType` objects, one per
        component of an element of this iterator.
      output_shapes: A nested structure of `tf.TensorShape` objects, one per
        component of an element of this iterator.
      output_classes: A nested structure of Python `type` objects, one per
        component of an element of this iterator.

    Raises:
      ValueError: If any of `output_types`, `output_shapes` or
        `output_classes` is `None`.
    """
    self._iterator_resource = iterator_resource
    self._initializer = initializer

    legacy_components = (output_types, output_shapes, output_classes)
    if any(component is None for component in legacy_components):
        raise ValueError(
            "If `structure` is not specified, all of "
            "`output_types`, `output_shapes`, and `output_classes`"
            " must be specified.")
    self._structure = structure_lib.convert_legacy_structure(
        *legacy_components)

    self._string_handle = gen_dataset_ops.iterator_to_string_handle(
        self._iterator_resource)
    self._get_next_call_count = 0
    # Record the resource in the graph-level collection of iterators.
    ops.add_to_collection(GLOBAL_ITERATORS, self._iterator_resource)
def __init__(self, pipeline='', batch_size=1, num_threads=4, device_id=0,
             exec_separated=False, prefetch_queue_depth=2,
             cpu_prefetch_queue_depth=2, gpu_prefetch_queue_depth=2,
             shapes=(), dtypes=()):
    """Builds the dataset from a pipeline and its output signature.

    Args:
      pipeline: Object exposing `serialize()`; the serialized form is stored.
      batch_size: Stored as `_batch_size`.
      num_threads: Stored as `_num_threads`.
      device_id: Stored as `_device_id`.
      exec_separated: Stored as `_exec_separated`.
      prefetch_queue_depth: Stored as `_prefetch_queue_depth`.
      cpu_prefetch_queue_depth: Stored as `_cpu_prefetch_queue_depth`.
      gpu_prefetch_queue_depth: Stored as `_gpu_prefetch_queue_depth`.
      shapes: Sequence of output shapes, each converted to `tf.TensorShape`.
      dtypes: Sequence of output dtypes, one per entry in `shapes`.

    Raises:
      ValueError: If `shapes` and `dtypes` have different lengths.
    """
    # The former `assert (cond, msg)` asserted a non-empty tuple, which is
    # always truthy, so mismatched lengths were never detected.
    if len(shapes) != len(dtypes):
        raise ValueError("Different number of provided shapes and dtypes.")

    # Every output component is a plain tensor.
    output_classes = tuple(ops.Tensor for _ in shapes)

    self._pipeline = pipeline.serialize()
    self._batch_size = batch_size
    self._num_threads = num_threads
    self._device_id = device_id
    self._exec_separated = exec_separated
    self._prefetch_queue_depth = prefetch_queue_depth
    self._cpu_prefetch_queue_depth = cpu_prefetch_queue_depth
    self._gpu_prefetch_queue_depth = gpu_prefetch_queue_depth
    self._shapes = tuple(tf.TensorShape(shape) for shape in shapes)
    self._dtypes = tuple(dtypes)

    self._structure = structure.convert_legacy_structure(
        self._dtypes, self._shapes, output_classes)

    super(_DALIDatasetV2, self).__init__(self._as_variant_tensor())
def __init__(self, input_dataset, num_workers):
    """Builds a dataset whose batch dimension is divided among workers.

    Args:
      input_dataset: The dataset to rebatch.
      num_workers: Divisor applied (rounding up) to the leading dimension of
        every component shape; also forwarded to the rebatch op.

    Raises:
      ValueError: If a component shape has no dimensions.
    """
    self._input_dataset = input_dataset

    def per_worker_shape(full_shape):
        """Returns `full_shape` with its first dim ceil-divided by workers."""
        if len(full_shape) < 1:
            raise ValueError(
                "Input shape should have at least one dimension. "
                "Perhaps your input dataset is not batched?")
        dims = list(full_shape.dims)
        # Ceiling division: add (num_workers - 1) before flooring.
        dims[0] = (dims[0] + num_workers - 1) // num_workers
        return tensor_shape.TensorShape(dims)

    input_types = dataset_ops.get_legacy_output_types(self._input_dataset)
    input_shapes = dataset_ops.get_legacy_output_shapes(self._input_dataset)
    input_classes = dataset_ops.get_legacy_output_classes(
        self._input_dataset)
    rebatched_shapes = nest.map_structure(per_worker_shape, input_shapes)

    self._structure = structure.convert_legacy_structure(
        input_types, rebatched_shapes, input_classes)

    # pylint: disable=protected-access
    input_variant = self._input_dataset._variant_tensor
    variant_tensor = ged_ops.experimental_rebatch_dataset(
        input_variant, num_workers=num_workers, **self._flat_structure)
    super(_RebatchDataset, self).__init__(input_dataset, variant_tensor)
def __init__(self, input_dataset, window_size, window_shift, window_stride):
    """See `sliding_window_batch` for details.

    Args:
      input_dataset: The dataset whose elements are grouped into windows.
      window_size: Converted to an int64 tensor and passed to the op as
        `window_size`.
      window_shift: Converted to an int64 tensor and passed to the op as
        `window_shift`.
      window_stride: Converted to an int64 tensor and passed to the op as
        `window_stride`.
    """
    self._input_dataset = input_dataset
    # Each tensor is named after the argument it carries so that the
    # resulting graph nodes can be told apart.
    self._window_size = ops.convert_to_tensor(
        window_size, dtype=dtypes.int64, name="window_size")
    self._window_stride = ops.convert_to_tensor(
        window_stride, dtype=dtypes.int64, name="window_stride")
    self._window_shift = ops.convert_to_tensor(
        window_shift, dtype=dtypes.int64, name="window_shift")

    input_structure = structure.convert_legacy_structure(
        input_dataset.output_types, input_dataset.output_shapes,
        input_dataset.output_classes)
    # Output elements gain a leading dimension of unknown size.
    self._structure = input_structure._batch(None)  # pylint: disable=protected-access

    variant_tensor = ged_ops.experimental_sliding_window_dataset(
        self._input_dataset._variant_tensor,  # pylint: disable=protected-access
        window_size=self._window_size,
        window_shift=self._window_shift,
        window_stride=self._window_stride,
        **dataset_ops.flat_structure(self))
    super(_SlideDataset, self).__init__(input_dataset, variant_tensor)
def __init__(self, make_variant_fn, columns, output_types,
             output_shapes=None, batch_size=None,
             batch_mode='keep_remainder'):
    """Builds the element structure and the dataset variant tensor.

    Args:
      make_variant_fn: Callable invoked with `columns`, `batch_size`,
        `batch_mode` and the flat structure keyword arguments; its result is
        used as the dataset variant tensor.
      columns: Stored as `_columns` and forwarded to `make_variant_fn`.
      output_types: Nested structure of output dtypes.
      output_shapes: Optional nested structure of output shapes. When falsy,
        every component gets an unknown shape.
      batch_size: Optional batch size; a falsy value is converted to 0 for
        the tensor passed to `make_variant_fn`.
      batch_mode: One of `self.batch_modes_supported`.

    Raises:
      ValueError: If `batch_mode` is not in `self.batch_modes_supported`.
    """
    self._columns = columns
    self._structure = structure_lib.convert_legacy_structure(
        output_types,
        output_shapes or nest.map_structure(
            lambda _: tf.TensorShape(None), output_types),
        nest.map_structure(lambda _: tf.Tensor, output_types))
    self._batch_size = tf.convert_to_tensor(
        batch_size or 0, dtype=dtypes.int64, name="batch_size")

    if batch_mode not in self.batch_modes_supported:
        raise ValueError(
            "Unsupported batch_mode: '{}', must be one of {}".format(
                batch_mode, self.batch_modes_supported))
    self._batch_mode = tf.convert_to_tensor(
        batch_mode, dtypes.string, name="batch_mode")

    if batch_size is not None or batch_mode == 'auto':
        # The batch dimension is only statically known when the remainder is
        # dropped; otherwise it is left unknown.
        spec_batch_size = (
            batch_size if batch_mode == 'drop_remainder' else None)
        # pylint: disable=protected-access
        self._structure = nest.map_structure(
            lambda component_spec: component_spec._batch(spec_batch_size),
            self._structure)

    # A leftover debug `print(self._flat_structure)` used to sit here; it
    # wrote to stdout on every construction and has been removed.
    variant_tensor = make_variant_fn(
        columns=self._columns,
        batch_size=self._batch_size,
        batch_mode=self._batch_mode,
        **self._flat_structure)
    super(ArrowBaseDataset, self).__init__(variant_tensor)
def __init__(self, iterator_resource, initializer, output_types,
             output_shapes, output_classes):
    """Wraps an existing iterator resource in an iterator object.

    Note: This initializer is rarely called directly; iterators are normally
    obtained from `Dataset.make_initializable_iterator()` or
    `Dataset.make_one_shot_iterator()`.

    Args:
      iterator_resource: A `tf.resource` scalar `tf.Tensor` representing the
        iterator.
      initializer: A `tf.Operation` that should be run to initialize this
        iterator.
      output_types: A nested structure of `tf.DType` objects, one per
        component of an element of this iterator.
      output_shapes: A nested structure of `tf.TensorShape` objects, one per
        component of an element of this iterator.
      output_classes: A nested structure of Python `type` objects, one per
        component of an element of this iterator.

    Raises:
      ValueError: If any of `output_types`, `output_shapes` or
        `output_classes` is `None`.
    """
    self._iterator_resource = iterator_resource
    self._initializer = initializer

    legacy_components = (output_types, output_shapes, output_classes)
    if any(component is None for component in legacy_components):
        raise ValueError(
            "If `structure` is not specified, all of "
            "`output_types`, `output_shapes`, and `output_classes`"
            " must be specified.")
    self._structure = structure_lib.convert_legacy_structure(
        *legacy_components)

    self._string_handle = gen_dataset_ops.iterator_to_string_handle(
        self._iterator_resource)
    self._get_next_call_count = 0
    # Record the resource in the graph-level collection of iterators.
    ops.add_to_collection(GLOBAL_ITERATORS, self._iterator_resource)
def __init__(self, input_dataset, num_workers):
    """Builds a dataset whose batch dimension is split evenly across workers.

    Args:
      input_dataset: The dataset to rebatch.
      num_workers: Divisor applied to the leading dimension of every
        component shape; also forwarded to the rebatch op.

    Raises:
      ValueError: If a component shape has no dimensions.
      errors.InvalidArgumentError: If a known, non-zero leading dimension is
        not divisible by `num_workers`.
    """
    self._input_dataset = input_dataset

    def per_worker_shape(full_shape):
        """Returns `full_shape` with its first dim divided by num_workers."""
        if len(full_shape) < 1:
            raise ValueError(
                "Input shape should have at least one dimension.")
        batch_value = tensor_shape.dimension_value(full_shape[0])
        # A falsy value (unknown or zero) skips the divisibility check.
        if batch_value and batch_value % num_workers != 0:
            raise errors.InvalidArgumentError(
                None, None,
                "First dim of input shape: %d is not divisible by num_workers: %d"
                % (full_shape[0], num_workers))
        dims = list(full_shape.dims)
        dims[0] = dims[0] // num_workers
        return tensor_shape.TensorShape(dims)

    input_types = dataset_ops.get_legacy_output_types(self._input_dataset)
    input_shapes = dataset_ops.get_legacy_output_shapes(self._input_dataset)
    input_classes = dataset_ops.get_legacy_output_classes(
        self._input_dataset)
    rebatched_shapes = nest.map_structure(per_worker_shape, input_shapes)

    self._structure = structure.convert_legacy_structure(
        input_types, rebatched_shapes, input_classes)

    # pylint: disable=protected-access
    input_variant = self._input_dataset._variant_tensor
    variant_tensor = ged_ops.experimental_rebatch_dataset(
        input_variant,
        num_workers=num_workers,
        **dataset_ops.flat_structure(self))
    super(_RebatchDataset, self).__init__(input_dataset, variant_tensor)
def __init__(self, selector_input, data_inputs):
    """Validates the data inputs and derives the common element spec.

    Args:
      selector_input: Stored as `_selector_input`.
      data_inputs: Indexable collection of datasets; all must share the
        output types and classes of the first one.

    Raises:
      TypeError: If any data input differs from the first in output types or
        output classes.
    """
    self._selector_input = selector_input
    self._data_inputs = list(data_inputs)

    reference_input = data_inputs[0]
    first_output_types = dataset_ops.get_legacy_output_types(
        reference_input)
    first_output_classes = dataset_ops.get_legacy_output_classes(
        reference_input)

    for candidate in data_inputs[1:]:
        types_match = (
            dataset_ops.get_legacy_output_types(candidate)
            == first_output_types)
        if not types_match or (
                dataset_ops.get_legacy_output_classes(candidate)
                != first_output_classes):
            raise TypeError(
                "All datasets must have the same type and class.")

    # Relax the first input's shapes, one further input at a time, to the
    # most specific shapes compatible with every input.
    merged_shapes = dataset_ops.get_legacy_output_shapes(
        self._data_inputs[0])
    for candidate in self._data_inputs[1:]:
        candidate_shapes = dataset_ops.get_legacy_output_shapes(candidate)
        shape_pairs = zip(
            nest.flatten(merged_shapes), nest.flatten(candidate_shapes))
        merged_shapes = nest.pack_sequence_as(
            merged_shapes,
            [left.most_specific_compatible_shape(right)
             for left, right in shape_pairs])

    self._element_spec = structure.convert_legacy_structure(
        first_output_types, merged_shapes, first_output_classes)
    super(_DirectedInterleaveDataset, self).__init__()
def __init__(self, selector_input, data_inputs):
    """Validates the data inputs and creates the interleave variant tensor.

    Args:
      selector_input: Dataset whose variant tensor is passed to the op as the
        selector.
      data_inputs: Indexable collection of datasets; all must share the
        output types and classes of the first one.

    Raises:
      TypeError: If any data input differs from the first in output types or
        output classes.
    """
    self._selector_input = selector_input
    self._data_inputs = list(data_inputs)

    reference_input = data_inputs[0]
    first_output_types = dataset_ops.get_legacy_output_types(
        reference_input)
    first_output_classes = dataset_ops.get_legacy_output_classes(
        reference_input)

    for candidate in data_inputs[1:]:
        types_match = (
            dataset_ops.get_legacy_output_types(candidate)
            == first_output_types)
        if not types_match or (
                dataset_ops.get_legacy_output_classes(candidate)
                != first_output_classes):
            raise TypeError(
                "All datasets must have the same type and class.")

    # Relax the first input's shapes, one further input at a time, to the
    # most specific shapes compatible with every input.
    merged_shapes = dataset_ops.get_legacy_output_shapes(
        self._data_inputs[0])
    for candidate in self._data_inputs[1:]:
        candidate_shapes = dataset_ops.get_legacy_output_shapes(candidate)
        shape_pairs = zip(
            nest.flatten(merged_shapes), nest.flatten(candidate_shapes))
        merged_shapes = nest.pack_sequence_as(
            merged_shapes,
            [left.most_specific_compatible_shape(right)
             for left, right in shape_pairs])

    self._element_spec = structure.convert_legacy_structure(
        first_output_types, merged_shapes, first_output_classes)

    # pylint: disable=protected-access
    selector_variant = self._selector_input._variant_tensor
    data_variants = [
        candidate._variant_tensor for candidate in self._data_inputs
    ]
    variant_tensor = gen_experimental_dataset_ops.directed_interleave_dataset(
        selector_variant, data_variants, **self._flat_structure)
    super(_DirectedInterleaveDataset, self).__init__(variant_tensor)
def __init__(self, input_dataset, num_workers):
    """Builds a dataset whose batch dimension is split evenly across workers.

    Args:
      input_dataset: The dataset to rebatch.
      num_workers: Divisor applied to the leading dimension of every
        component shape; also forwarded to the rebatch op.

    Raises:
      ValueError: If a component shape has no dimensions.
      errors.InvalidArgumentError: If a known, non-zero leading dimension is
        not divisible by `num_workers`.
    """
    self._input_dataset = input_dataset

    def per_worker_shape(full_shape):
        """Returns `full_shape` with its first dim divided by num_workers."""
        if len(full_shape) < 1:
            raise ValueError(
                "Input shape should have at least one dimension.")
        batch_value = tensor_shape.dimension_value(full_shape[0])
        # A falsy value (unknown or zero) skips the divisibility check.
        if batch_value and batch_value % num_workers != 0:
            raise errors.InvalidArgumentError(
                None, None,
                "First dim of input shape: %d is not divisible by num_workers: %d"
                % (full_shape[0], num_workers))
        dims = list(full_shape.dims)
        dims[0] = dims[0] // num_workers
        return tensor_shape.TensorShape(dims)

    input_types = dataset_ops.get_legacy_output_types(self._input_dataset)
    input_shapes = dataset_ops.get_legacy_output_shapes(self._input_dataset)
    input_classes = dataset_ops.get_legacy_output_classes(
        self._input_dataset)
    rebatched_shapes = nest.map_structure(per_worker_shape, input_shapes)

    self._structure = structure.convert_legacy_structure(
        input_types, rebatched_shapes, input_classes)

    # pylint: disable=protected-access
    input_variant = self._input_dataset._variant_tensor
    variant_tensor = ged_ops.experimental_rebatch_dataset(
        input_variant,
        num_workers=num_workers,
        **dataset_ops.flat_structure(self))
    super(_RebatchDataset, self).__init__(input_dataset, variant_tensor)
def __init__(self, input_dataset, num_replicas, use_fallback=True):
    """Builds a dataset whose batch dimension is divided among replicas.

    Args:
      input_dataset: The dataset to rebatch.
      num_replicas: Divisor applied to the leading dimension of every
        component shape; also forwarded to the rebatch op.
      use_fallback: Not read by this initializer.

    Raises:
      ValueError: If a component shape has no dimensions.
    """
    self._input_dataset = input_dataset

    def per_replica_shape(full_shape):
        """Returns `full_shape` with its first dim divided by num_replicas."""
        if len(full_shape) < 1:
            raise ValueError(
                "Input shape should have at least one dimension. "
                "Perhaps your input dataset is not batched?")
        dim_values = [dim.value for dim in full_shape.dims]
        batch_value = dim_values[0]
        divides_evenly = (
            batch_value is not None and batch_value % num_replicas == 0)
        # An unknown or unevenly divisible global batch may yield minibatches
        # of different sizes, so the batch dimension becomes unknown.
        dim_values[0] = (
            batch_value // num_replicas if divides_evenly else None)
        return tensor_shape.TensorShape(dim_values)

    input_types = dataset_ops.get_legacy_output_types(self._input_dataset)
    input_shapes = dataset_ops.get_legacy_output_shapes(self._input_dataset)
    input_classes = dataset_ops.get_legacy_output_classes(
        self._input_dataset)
    rebatched_shapes = nest.map_structure(per_replica_shape, input_shapes)

    self._element_spec = structure.convert_legacy_structure(
        input_types, rebatched_shapes, input_classes)

    # pylint: disable=protected-access
    input_variant = self._input_dataset._variant_tensor
    variant_tensor = ged_ops.rebatch_dataset(
        input_variant, num_replicas=num_replicas, **self._flat_structure)
    super(_RebatchDataset, self).__init__(input_dataset, variant_tensor)
def __init__(self, input_dataset, features, num_parallel_calls):
    """Derives the parsed element structure and creates the parse op.

    Args:
      input_dataset: Dataset that must be compatible with vectors of strings.
      features: Feature specification, normalised with
        `parsing_ops._prepend_none_dimension`.
      num_parallel_calls: Stored and forwarded to the parse op.

    Raises:
      TypeError: If `input_dataset` is not compatible with a string vector
        structure.
    """
    self._input_dataset = input_dataset
    # pylint: disable=protected-access
    expected_input = structure.TensorStructure(dtypes.string, [None])
    if not input_dataset._element_structure.is_compatible_with(
            expected_input):
        raise TypeError(
            "Input dataset should be a dataset of vectors of strings")
    self._num_parallel_calls = num_parallel_calls
    self._features = parsing_ops._prepend_none_dimension(features)

    # sparse_keys and dense_keys come back sorted here.
    supported_feature_types = [
        parsing_ops.VarLenFeature,
        parsing_ops.SparseFeature,
        parsing_ops.FixedLenFeature,
        parsing_ops.FixedLenSequenceFeature,
    ]
    raw_params = parsing_ops._features_to_raw_params(
        self._features, supported_feature_types)
    (sparse_keys, sparse_types, dense_keys, dense_types, dense_defaults,
     dense_shapes) = raw_params

    # TODO(b/112859642): Pass sparse_index and sparse_values for
    # SparseFeature.
    processed_params = parsing_ops._process_raw_parameters(
        None, dense_defaults, sparse_keys, sparse_types, dense_keys,
        dense_types, dense_shapes)
    (_, dense_defaults_vec, sparse_keys, sparse_types, dense_keys,
     dense_shapes, dense_shape_as_shape) = processed_params
    # pylint: enable=protected-access

    self._sparse_keys = sparse_keys
    self._sparse_types = sparse_types
    self._dense_keys = dense_keys
    self._dense_defaults = dense_defaults_vec
    self._dense_shapes = dense_shapes
    # `dense_types` still holds the value from the first unpacking above.
    self._dense_types = dense_types

    # Every output shape is the input shape followed by the feature's shape.
    input_dataset_shape = dataset_ops.get_legacy_output_shapes(
        self._input_dataset)
    dense_output_shapes = [
        input_dataset_shape.concatenate(dense_shape)
        for dense_shape in dense_shape_as_shape
    ]
    sparse_output_shapes = [
        input_dataset_shape.concatenate([None]) for _ in sparse_keys
    ]

    # Dense entries come first, then sparse entries, under the same keys.
    all_keys = self._dense_keys + self._sparse_keys
    output_shapes = dict(
        zip(all_keys, dense_output_shapes + sparse_output_shapes))
    output_types = dict(
        zip(all_keys, self._dense_types + self._sparse_types))
    dense_classes = [ops.Tensor] * len(self._dense_defaults)
    sparse_classes = [sparse_tensor.SparseTensor] * len(self._sparse_keys)
    output_classes = dict(zip(all_keys, dense_classes + sparse_classes))

    self._structure = structure.convert_legacy_structure(
        output_types, output_shapes, output_classes)

    variant_tensor = (
        gen_experimental_dataset_ops.experimental_parse_example_dataset(
            self._input_dataset._variant_tensor,  # pylint: disable=protected-access
            self._num_parallel_calls,
            self._dense_defaults,
            self._sparse_keys,
            self._dense_keys,
            self._sparse_types,
            self._dense_shapes,
            **dataset_ops.flat_structure(self)))
    super(_ParseExampleDataset, self).__init__(input_dataset, variant_tensor)
def testConvertLegacyStructure(self, output_types, output_shapes,
                               output_classes, expected_structure):
    """Check that the converted structure matches the expected one.

    Compatibility is asserted in both directions between the expected
    structure and the one built from the legacy arguments.
    """
    converted = structure.convert_legacy_structure(
        output_types, output_shapes, output_classes)
    expected_accepts_converted = expected_structure.is_compatible_with(
        converted)
    self.assertTrue(expected_accepts_converted)
    converted_accepts_expected = converted.is_compatible_with(
        expected_structure)
    self.assertTrue(converted_accepts_expected)
def __init__(self, pipeline, output_dtypes=None, output_shapes=None,
             fail_on_device_mismatch=True, *, batch_size=1, num_threads=4,
             device_id=0, exec_separated=False, prefetch_queue_depth=2,
             cpu_prefetch_queue_depth=2, gpu_prefetch_queue_depth=2,
             dtypes=None, shapes=None):
    """Builds the dataset from a pipeline and its output signature.

    Args:
      pipeline: Pipeline passed to `serialize_pipeline`.
      output_dtypes: A single dtype or a tuple of dtypes.
      output_shapes: Optional shapes matching `output_dtypes`; unknown shapes
        are used when omitted.
      fail_on_device_mismatch: Stored as `_fail_on_device_mismatch`.
      batch_size: Stored as `_batch_size`.
      num_threads: Stored as `_num_threads`.
      device_id: Stored as `_device_id`; `None` is replaced with
        `types.CPU_ONLY_DEVICE_ID`.
      exec_separated: Stored as `_exec_separated`.
      prefetch_queue_depth: Stored as `_prefetch_queue_depth`.
      cpu_prefetch_queue_depth: Stored as `_cpu_prefetch_queue_depth`.
      gpu_prefetch_queue_depth: Stored as `_gpu_prefetch_queue_depth`.
      dtypes: Deprecated counterpart of `output_dtypes`, resolved through
        `_handle_deprecation`.
      shapes: Deprecated counterpart of `output_shapes`, resolved through
        `_handle_deprecation`.

    Raises:
      TypeError: If `_check_output_dtypes` rejects `output_dtypes`.
    """
    output_shapes = self._handle_deprecation(output_shapes, shapes, "shapes")
    output_dtypes = self._handle_deprecation(output_dtypes, dtypes, "dtypes")

    if not self._check_output_dtypes(output_dtypes):
        message_template = (
            "`output_dtypes` should be provided as single tf.DType value " +
            "or a tuple of tf.DType values. Got value `{}` of type `{}`.")
        raise TypeError(
            message_template.format(output_dtypes, type(output_dtypes)))

    if output_shapes is not None:
        output_shapes = nest.map_structure_up_to(
            output_dtypes, tensor_shape.as_shape, output_shapes)
    else:
        output_shapes = nest.map_structure(
            lambda _: tensor_shape.TensorShape(None), output_dtypes)

    # A single dtype is promoted to a one-element tuple, with its shape.
    if not isinstance(output_dtypes, tuple):
        output_dtypes = (output_dtypes,)
        output_shapes = (output_shapes,)

    output_classes = nest.map_structure(lambda _: ops.Tensor, output_dtypes)

    self._pipeline = serialize_pipeline(pipeline)
    self._batch_size = batch_size
    self._num_threads = num_threads

    self._device_id = (
        types.CPU_ONLY_DEVICE_ID if device_id is None else device_id)

    self._exec_separated = exec_separated
    self._prefetch_queue_depth = prefetch_queue_depth
    self._cpu_prefetch_queue_depth = cpu_prefetch_queue_depth
    self._gpu_prefetch_queue_depth = gpu_prefetch_queue_depth
    self._output_shapes = output_shapes
    self._output_dtypes = output_dtypes
    self._fail_on_device_mismatch = fail_on_device_mismatch

    self._structure = structure.convert_legacy_structure(
        self._output_dtypes, self._output_shapes, output_classes)

    super(_DALIDatasetV2, self).__init__(self._as_variant_tensor())
def __init__(self, cache_name, host="localhost", port=10800, local=False,
             part=-1, page_size=100, username=None, password=None,
             certfile=None, keyfile=None, cert_password=None):
    """Create a IgniteDataset.

    The cache type is fetched through a temporary `IgniteClient` connection
    and then used to build the schema, permutation and element structure.

    Args:
      cache_name: Cache name to be used as datasource.
      host: Apache Ignite Thin Client host to be connected.
      port: Apache Ignite Thin Client port to be connected.
      local: Local flag that defines to query only local data.
      part: Number of partitions to be queried.
      page_size: Apache Ignite Thin Client page size.
      username: Apache Ignite Thin Client authentication username.
      password: Apache Ignite Thin Client authentication password.
      certfile: File in PEM format containing the certificate as well as any
        number of CA certificates needed to establish the certificate's
        authenticity.
      keyfile: File containing the private key (otherwise the private key
        will be taken from certfile as well).
      cert_password: Password to be used if the private key is encrypted and
        a password is necessary.
    """
    connection = IgniteClient(host, port, username, password, certfile,
                              keyfile, cert_password)
    with connection as client:
        client.handshake()
        self.cache_type = client.get_cache_type(cache_name)

    self.cache_name = ops.convert_to_tensor(
        cache_name, dtype=dtypes.string, name="cache_name")
    self.host = ops.convert_to_tensor(
        host, dtype=dtypes.string, name="host")
    self.port = ops.convert_to_tensor(port, dtype=dtypes.int32, name="port")
    self.local = ops.convert_to_tensor(
        local, dtype=dtypes.bool, name="local")
    self.part = ops.convert_to_tensor(part, dtype=dtypes.int32, name="part")
    self.page_size = ops.convert_to_tensor(
        page_size, dtype=dtypes.int32, name="page_size")

    flat_schema = self.cache_type.to_flat()
    self.schema = ops.convert_to_tensor(
        flat_schema, dtype=dtypes.int32, name="schema")
    field_permutation = self.cache_type.to_permutation()
    self.permutation = ops.convert_to_tensor(
        field_permutation, dtype=dtypes.int32, name="permutation")

    self._structure = structure.convert_legacy_structure(
        self.cache_type.to_output_types(),
        self.cache_type.to_output_shapes(),
        self.cache_type.to_output_classes())

    super(IgniteDataset, self).__init__(self._as_variant_tensor())
def __init__(self, cache_name, host="localhost", port=10800, local=False,
             part=-1, page_size=100, username=None, password=None,
             certfile=None, keyfile=None, cert_password=None):
    """Create a IgniteDataset.

    The cache type is fetched through a temporary `IgniteClient` connection
    and then used to build the schema, permutation and element spec.

    Args:
      cache_name: Cache name to be used as datasource.
      host: Apache Ignite Thin Client host to be connected.
      port: Apache Ignite Thin Client port to be connected.
      local: Local flag that defines to query only local data.
      part: Number of partitions to be queried.
      page_size: Apache Ignite Thin Client page size.
      username: Apache Ignite Thin Client authentication username.
      password: Apache Ignite Thin Client authentication password.
      certfile: File in PEM format containing the certificate as well as any
        number of CA certificates needed to establish the certificate's
        authenticity.
      keyfile: File containing the private key (otherwise the private key
        will be taken from certfile as well).
      cert_password: Password to be used if the private key is encrypted and
        a password is necessary.
    """
    connection = IgniteClient(host, port, username, password, certfile,
                              keyfile, cert_password)
    with connection as client:
        client.handshake()
        self.cache_type = client.get_cache_type(cache_name)

    self.cache_name = ops.convert_to_tensor(
        cache_name, dtype=dtypes.string, name="cache_name")
    self.host = ops.convert_to_tensor(
        host, dtype=dtypes.string, name="host")
    self.port = ops.convert_to_tensor(port, dtype=dtypes.int32, name="port")
    self.local = ops.convert_to_tensor(
        local, dtype=dtypes.bool, name="local")
    self.part = ops.convert_to_tensor(part, dtype=dtypes.int32, name="part")
    self.page_size = ops.convert_to_tensor(
        page_size, dtype=dtypes.int32, name="page_size")

    flat_schema = self.cache_type.to_flat()
    self.schema = ops.convert_to_tensor(
        flat_schema, dtype=dtypes.int32, name="schema")
    field_permutation = self.cache_type.to_permutation()
    self.permutation = ops.convert_to_tensor(
        field_permutation, dtype=dtypes.int32, name="permutation")

    self._element_spec = structure.convert_legacy_structure(
        self.cache_type.to_output_types(),
        self.cache_type.to_output_shapes(),
        self.cache_type.to_output_classes())

    super(IgniteDataset, self).__init__(self._as_variant_tensor())
def __init__(self, selector_input, data_inputs, stop_on_empty_dataset=False):
    """Validates the data inputs and creates the interleave variant tensor.

    Args:
      selector_input: Dataset whose variant tensor is passed to the op as the
        selector.
      data_inputs: Indexable collection of datasets; all must share the
        output types and classes of the first one.
      stop_on_empty_dataset: Stored, and forwarded to the op when it is truthy
        or when the forward-compatibility window has been reached.

    Raises:
      TypeError: If any data input differs from the first in output types or
        output classes.
    """
    self._selector_input = selector_input
    self._data_inputs = list(data_inputs)
    self._stop_on_empty_dataset = stop_on_empty_dataset

    reference_input = data_inputs[0]
    first_output_types = dataset_ops.get_legacy_output_types(
        reference_input)
    first_output_classes = dataset_ops.get_legacy_output_classes(
        reference_input)

    # Positions are offset by one because the first dataset is the reference.
    for position, candidate in enumerate(data_inputs[1:], 1):
        types_match = (
            dataset_ops.get_legacy_output_types(candidate)
            == first_output_types)
        if not types_match or (
                dataset_ops.get_legacy_output_classes(candidate)
                != first_output_classes):
            raise TypeError(
                "All datasets must have the same type and class.\n"
                "dataset 0 vs dataset %s types: %s ; %s\n"
                "classes: %s ; %s" % (
                    position,
                    first_output_types,
                    dataset_ops.get_legacy_output_types(candidate),
                    first_output_classes,
                    dataset_ops.get_legacy_output_classes(candidate)))

    # Relax the first input's shapes, one further input at a time, to the
    # most specific shapes compatible with every input.
    merged_shapes = dataset_ops.get_legacy_output_shapes(
        self._data_inputs[0])
    for candidate in self._data_inputs[1:]:
        candidate_shapes = dataset_ops.get_legacy_output_shapes(candidate)
        shape_pairs = zip(
            nest.flatten(merged_shapes), nest.flatten(candidate_shapes))
        merged_shapes = nest.pack_sequence_as(
            merged_shapes,
            [left.most_specific_compatible_shape(right)
             for left, right in shape_pairs])

    self._element_spec = structure.convert_legacy_structure(
        first_output_types, merged_shapes, first_output_classes)

    # The extra op argument is only supplied when requested or once the
    # forward-compatibility date has passed.
    compat_kwargs = {}
    if compat.forward_compatible(2021, 5, 14) or self._stop_on_empty_dataset:
        compat_kwargs["stop_on_empty_dataset"] = self._stop_on_empty_dataset

    # pylint: disable=protected-access
    selector_variant = self._selector_input._variant_tensor
    data_variants = [
        candidate._variant_tensor for candidate in self._data_inputs
    ]
    variant_tensor = gen_experimental_dataset_ops.directed_interleave_dataset(
        selector_variant,
        data_variants,
        **compat_kwargs,
        **self._flat_structure)
    super(_DirectedInterleaveDataset, self).__init__(variant_tensor)
def __init__(self, input_dataset, features, num_parallel_calls):
    """Derives the element structure of the parsed examples.

    Args:
      input_dataset: Dataset that must be compatible with vectors of strings.
      features: Feature specification, normalised with
        `parsing_ops._prepend_none_dimension`.
      num_parallel_calls: Stored as `_num_parallel_calls`.

    Raises:
      TypeError: If `input_dataset` is not compatible with a string vector
        structure.
    """
    super(_ParseExampleDataset, self).__init__(input_dataset)
    self._input_dataset = input_dataset
    # pylint: disable=protected-access
    expected_input = structure.TensorStructure(dtypes.string, [None])
    if not input_dataset._element_structure.is_compatible_with(
            expected_input):
        raise TypeError(
            "Input dataset should be a dataset of vectors of strings")
    self._num_parallel_calls = num_parallel_calls
    self._features = parsing_ops._prepend_none_dimension(features)

    # sparse_keys and dense_keys come back sorted here.
    supported_feature_types = [
        parsing_ops.VarLenFeature,
        parsing_ops.SparseFeature,
        parsing_ops.FixedLenFeature,
        parsing_ops.FixedLenSequenceFeature,
    ]
    raw_params = parsing_ops._features_to_raw_params(
        self._features, supported_feature_types)
    (sparse_keys, sparse_types, dense_keys, dense_types, dense_defaults,
     dense_shapes) = raw_params

    # TODO(b/112859642): Pass sparse_index and sparse_values for
    # SparseFeature.
    processed_params = parsing_ops._process_raw_parameters(
        None, dense_defaults, sparse_keys, sparse_types, dense_keys,
        dense_types, dense_shapes)
    (_, dense_defaults_vec, sparse_keys, sparse_types, dense_keys,
     dense_shapes, dense_shape_as_shape) = processed_params
    # pylint: enable=protected-access

    self._sparse_keys = sparse_keys
    self._sparse_types = sparse_types
    self._dense_keys = dense_keys
    self._dense_defaults = dense_defaults_vec
    self._dense_shapes = dense_shapes
    # `dense_types` still holds the value from the first unpacking above.
    self._dense_types = dense_types

    # Every output shape is the input shape followed by the feature's shape.
    dense_output_shapes = [
        self._input_dataset.output_shapes.concatenate(dense_shape)
        for dense_shape in dense_shape_as_shape
    ]
    sparse_output_shapes = [
        self._input_dataset.output_shapes.concatenate([None])
        for _ in sparse_keys
    ]

    # Dense entries come first, then sparse entries, under the same keys.
    all_keys = self._dense_keys + self._sparse_keys
    output_shapes = dict(
        zip(all_keys, dense_output_shapes + sparse_output_shapes))
    output_types = dict(
        zip(all_keys, self._dense_types + self._sparse_types))
    dense_classes = [ops.Tensor] * len(self._dense_defaults)
    sparse_classes = [sparse_tensor.SparseTensor] * len(self._sparse_keys)
    output_classes = dict(zip(all_keys, dense_classes + sparse_classes))

    self._structure = structure.convert_legacy_structure(
        output_types, output_shapes, output_classes)
def __init__(self, dataset):
    """Creates a new iterator over the given dataset.

    For example:
    ```python
    dataset = tf.data.Dataset.range(4)
    for x in Iterator(dataset):
      print(x)
    ```

    Tensors produced will be placed on the device on which this iterator
    object was created.

    Args:
      dataset: A `tf.data.Dataset` object.

    Raises:
      RuntimeError: When invoked without eager execution enabled.
    """
    eager_enabled = context.executing_eagerly()
    if not eager_enabled:
        message = (
            "{} objects can only be used when eager execution is enabled, use "
            "tf.data.Dataset.make_initializable_iterator or "
            "tf.data.Dataset.make_one_shot_iterator for graph construction")
        raise RuntimeError(message.format(type(self)))

    # Remember the creation device before switching to the CPU scope below.
    self._device = context.context().device_name
    with ops.device("/cpu:0"):
        # pylint: disable=protected-access
        dataset = dataset._apply_options()
        ds_variant = dataset._variant_tensor
        element_structure = structure_lib.convert_legacy_structure(
            dataset.output_types, dataset.output_shapes,
            dataset.output_classes)
        self._structure = element_structure
        self._flat_output_types = element_structure._flat_types
        self._flat_output_shapes = element_structure._flat_shapes
        with ops.colocate_with(ds_variant):
            self._resource = gen_dataset_ops.anonymous_iterator(
                output_types=self._flat_output_types,
                output_shapes=self._flat_output_shapes)
            gen_dataset_ops.make_iterator(ds_variant, self._resource)
            # Delete the resource when this object is deleted
            self._resource_deleter = (
                resource_variable_ops.EagerResourceDeleter(
                    handle=self._resource, handle_device=self._device))
def __init__(self, dataset):
    """Creates a new iterator over the given dataset.

    For example:
    ```python
    dataset = tf.data.Dataset.range(4)
    for x in Iterator(dataset):
      print(x)
    ```

    Tensors produced will be placed on the device on which this iterator
    object was created.

    Args:
      dataset: A `tf.data.Dataset` object.

    Raises:
      RuntimeError: When invoked without eager execution enabled.
    """
    eager_enabled = context.executing_eagerly()
    if not eager_enabled:
        message = (
            "{} objects can only be used when eager execution is enabled, use "
            "tf.data.Dataset.make_initializable_iterator or "
            "tf.data.Dataset.make_one_shot_iterator for graph construction")
        raise RuntimeError(message.format(type(self)))

    # Remember the creation device before switching to the CPU scope below.
    self._device = context.context().device_name
    with ops.device("/cpu:0"):
        # pylint: disable=protected-access
        dataset = dataset._apply_options()
        ds_variant = dataset._variant_tensor
        element_structure = structure_lib.convert_legacy_structure(
            dataset.output_types, dataset.output_shapes,
            dataset.output_classes)
        self._structure = element_structure
        self._flat_output_types = element_structure._flat_types
        self._flat_output_shapes = element_structure._flat_shapes
        with ops.colocate_with(ds_variant):
            self._resource = gen_dataset_ops.anonymous_iterator(
                output_types=self._flat_output_types,
                output_shapes=self._flat_output_shapes)
            gen_dataset_ops.make_iterator(ds_variant, self._resource)
            # Delete the resource when this object is deleted
            self._resource_deleter = (
                resource_variable_ops.EagerResourceDeleter(
                    handle=self._resource, handle_device=self._device))
def __init__(self, input_dataset):
    """See `unbatch()` for more details.

    Args:
      input_dataset: The dataset whose elements are split along their
        leading dimension.

    Raises:
      ValueError: If any component is a scalar, or if the leading dimensions
        of the components cannot be merged with each other.
    """
    super(_UnbatchDataset, self).__init__(input_dataset)
    component_shapes = nest.flatten(input_dataset.output_shapes)
    for component_shape in component_shapes:
        if component_shape.ndims == 0:
            raise ValueError(
                "Cannot unbatch an input with scalar components.")

    # Merging every leading dimension into one running dimension fails as
    # soon as two components disagree on a known batch size.
    merged_batch_dim = tensor_shape.Dimension(None)
    try:
        for component_shape in component_shapes:
            merged_batch_dim = merged_batch_dim.merge_with(
                component_shape[0])
    except ValueError:
        raise ValueError("Cannot unbatch an input whose components have "
                         "different batch sizes.")

    self._input_dataset = input_dataset
    element_types = input_dataset.output_types
    # Drop the leading (batch) dimension from every component shape.
    element_shapes = nest.map_structure(
        lambda shape: shape[1:], input_dataset.output_shapes)
    element_classes = input_dataset.output_classes
    self._structure = structure.convert_legacy_structure(
        element_types, element_shapes, element_classes)
def __init__(self, selector_input, data_inputs):
    """Validates the data inputs and derives the common element structure.

    Args:
      selector_input: Stored as `_selector_input`.
      data_inputs: Indexable collection of datasets; all must share the
        output types and classes of the first one.

    Raises:
      TypeError: If any data input differs from the first in output types or
        output classes.
    """
    self._selector_input = selector_input
    self._data_inputs = list(data_inputs)

    for candidate in data_inputs[1:]:
        types_match = candidate.output_types == data_inputs[0].output_types
        if not types_match or (
                candidate.output_classes != data_inputs[0].output_classes):
            raise TypeError(
                "All datasets must have the same type and class.")

    # Relax the first input's shapes, one further input at a time, to the
    # most specific shapes compatible with every input.
    merged_shapes = self._data_inputs[0].output_shapes
    for candidate in self._data_inputs[1:]:
        shape_pairs = zip(
            nest.flatten(merged_shapes),
            nest.flatten(candidate.output_shapes))
        merged_shapes = nest.pack_sequence_as(
            merged_shapes,
            [left.most_specific_compatible_shape(right)
             for left, right in shape_pairs])

    self._structure = structure.convert_legacy_structure(
        data_inputs[0].output_types, merged_shapes,
        data_inputs[0].output_classes)
def __init__(self, selector_input, data_inputs):
  """Records the inputs and derives the combined element structure.

  Args:
    selector_input: Stored as `self._selector_input`.
    data_inputs: A non-empty iterable of datasets. It is copied into a list
      once and only that copy is used afterwards, so any iterable (not just
      a subscriptable sequence) is accepted.

  Raises:
    TypeError: If any dataset's output types or classes differ from those of
      the first dataset.
    IndexError: If `data_inputs` is empty.
  """
  self._selector_input = selector_input
  self._data_inputs = list(data_inputs)

  # Use the materialized copy from here on; the raw argument may not support
  # indexing or may already be exhausted.
  first_input = self._data_inputs[0]
  remaining_inputs = self._data_inputs[1:]

  for data_input in remaining_inputs:
    if (data_input.output_types != first_input.output_types or
        data_input.output_classes != first_input.output_classes):
      raise TypeError("All datasets must have the same type and class.")

  # Relax the shapes component-wise until they are compatible with every
  # input dataset.
  output_shapes = first_input.output_shapes
  for data_input in remaining_inputs:
    output_shapes = nest.pack_sequence_as(output_shapes, [
        ts1.most_specific_compatible_shape(ts2) for (ts1, ts2) in zip(
            nest.flatten(output_shapes),
            nest.flatten(data_input.output_shapes))
    ])

  self._structure = structure.convert_legacy_structure(
      first_input.output_types, output_shapes, first_input.output_classes)
  super(_DirectedInterleaveDataset, self).__init__()
def __init__(self, input_dataset, window_size, window_shift, window_stride):
  """See `sliding_window_batch` for details.

  Args:
    input_dataset: The dataset whose elements are grouped into windows.
    window_size: Converted to an int64 tensor named "window_size".
    window_shift: Converted to an int64 tensor named "window_shift".
    window_stride: Converted to an int64 tensor named "window_stride".
  """
  self._input_dataset = input_dataset
  # Each tensor is named after the argument it wraps.
  self._window_size = ops.convert_to_tensor(
      window_size, dtype=dtypes.int64, name="window_size")
  self._window_stride = ops.convert_to_tensor(
      window_stride, dtype=dtypes.int64, name="window_stride")
  self._window_shift = ops.convert_to_tensor(
      window_shift, dtype=dtypes.int64, name="window_shift")

  input_structure = structure.convert_legacy_structure(
      input_dataset.output_types, input_dataset.output_shapes,
      input_dataset.output_classes)
  # `_batch(None)` is applied to the input structure to get the output one.
  self._structure = input_structure._batch(None)  # pylint: disable=protected-access
  variant_tensor = ged_ops.experimental_sliding_window_dataset(
      self._input_dataset._variant_tensor,  # pylint: disable=protected-access
      window_size=self._window_size,
      window_shift=self._window_shift,
      window_stride=self._window_stride,
      **dataset_ops.flat_structure(self))
  super(_SlideDataset, self).__init__(input_dataset, variant_tensor)
def __init__(self, input_dataset, num_workers):
  """Divides the leading dimension of the input shape by `num_workers`.

  Args:
    input_dataset: The dataset to rebatch; its `output_shapes` is read as a
      single shape.
    num_workers: The divisor applied to the first dimension.

  Raises:
    ValueError: If the input shape has no dimensions, if its first dimension
      is unknown (or zero), or if that dimension is not divisible by
      `num_workers`.
  """
  self._input_dataset = input_dataset

  input_shape = input_dataset.output_shapes
  if len(input_shape) == 0:
    raise ValueError("Input shape should have at least one dimension.")
  batch_dim = input_shape.dims[0]
  if not batch_dim.value:
    raise ValueError("Cannot rebatch unknown batch size datasets.")
  if batch_dim.value % num_workers:
    raise ValueError(
        "First dim of input shape: %d is not divisible by num_workers: %d" %
        (input_shape[0], num_workers))

  # Same dimensions as the input, except the first one is split across
  # workers.
  rebatched_dims = list(input_shape.dims)
  rebatched_dims[0] = rebatched_dims[0] // num_workers
  rebatched_shape = tensor_shape.TensorShapeV1(rebatched_dims)

  self._structure = structure.convert_legacy_structure(
      self._input_dataset.output_types,
      rebatched_shape,
      self._input_dataset.output_classes)
  variant_tensor = ged_ops.experimental_rebatch_dataset(
      self._input_dataset._variant_tensor,  # pylint: disable=protected-access
      num_workers=num_workers,
      **dataset_ops.flat_structure(self))
  super(_RebatchDataset, self).__init__(input_dataset, variant_tensor)
def __init__(self, input_dataset, num_workers):
  """Divides the leading dimension of the input shape by `num_workers`.

  Args:
    input_dataset: The dataset to rebatch; its `output_shapes` is read as a
      single shape.
    num_workers: The divisor applied to the first dimension.

  Raises:
    ValueError: If the input shape has no dimensions, if its first dimension
      is unknown (or zero), or if that dimension is not divisible by
      `num_workers`.
  """
  self._input_dataset = input_dataset

  input_shape = input_dataset.output_shapes
  if len(input_shape) == 0:
    raise ValueError("Input shape should have at least one dimension.")
  batch_dim = input_shape.dims[0]
  if not batch_dim.value:
    raise ValueError("Cannot rebatch unknown batch size datasets.")
  if batch_dim.value % num_workers:
    raise ValueError(
        "First dim of input shape: %d is not divisible by num_workers: %d" %
        (input_shape[0], num_workers))

  # Same dimensions as the input, except the first one is split across
  # workers.
  rebatched_dims = list(input_shape.dims)
  rebatched_dims[0] = rebatched_dims[0] // num_workers
  rebatched_shape = tensor_shape.TensorShapeV1(rebatched_dims)

  self._structure = structure.convert_legacy_structure(
      self._input_dataset.output_types,
      rebatched_shape,
      self._input_dataset.output_classes)
  variant_tensor = ged_ops.experimental_rebatch_dataset(
      self._input_dataset._variant_tensor,  # pylint: disable=protected-access
      num_workers=num_workers,
      **dataset_ops.flat_structure(self))
  super(_RebatchDataset, self).__init__(input_dataset, variant_tensor)
def __init__(self,
             pipeline='',
             batch_size=1,
             num_threads=4,
             device_id=0,
             exec_separated=False,
             prefetch_queue_depth=2,
             cpu_prefetch_queue_depth=2,
             gpu_prefetch_queue_depth=2,
             shapes=(),
             dtypes=()):
  """Stores the pipeline configuration and builds the element structure.

  Args:
    pipeline: Object exposing `serialize()`; the serialized form is stored.
    batch_size: Stored as `self._batch_size`.
    num_threads: Stored as `self._num_threads`.
    device_id: Stored as `self._device_id`.
    exec_separated: Stored as `self._exec_separated`.
    prefetch_queue_depth: Stored as `self._prefetch_queue_depth`.
    cpu_prefetch_queue_depth: Stored as `self._cpu_prefetch_queue_depth`.
    gpu_prefetch_queue_depth: Stored as `self._gpu_prefetch_queue_depth`.
    shapes: Sequence of shapes, one per output; each is wrapped in
      `tf.TensorShape`.
    dtypes: Sequence of dtypes, one per output; must have the same length as
      `shapes`.

  Raises:
    ValueError: If `shapes` and `dtypes` have different lengths.
    RuntimeError: If the TensorFlow minor version reported at runtime is
      neither '13' nor '14'.
  """
  # The previous `assert (cond, msg)` asserted a non-empty tuple and could
  # never fail; validate explicitly so it also survives `python -O`.
  if len(shapes) != len(dtypes):
    raise ValueError("Different number of provided shapes and dtypes.")

  # Every output is declared as a plain tensor.
  output_classes = tuple(ops.Tensor for _ in shapes)

  self._pipeline = pipeline.serialize()
  self._batch_size = batch_size
  self._num_threads = num_threads
  self._device_id = device_id
  self._exec_separated = exec_separated
  self._prefetch_queue_depth = prefetch_queue_depth
  self._cpu_prefetch_queue_depth = cpu_prefetch_queue_depth
  self._gpu_prefetch_queue_depth = gpu_prefetch_queue_depth
  self._shapes = tuple(tf.TensorShape(shape) for shape in shapes)
  self._dtypes = tuple(dtypes)

  self._structure = structure.convert_legacy_structure(
      self._dtypes, self._shapes, output_classes)

  # The base-class constructor takes the variant tensor only on 1.14.
  tf_minor_version = _get_tf_minor_version()
  if tf_minor_version == '14':
    super(_DALIDatasetV2, self).__init__(self._as_variant_tensor())
  elif tf_minor_version == '13':
    super(_DALIDatasetV2, self).__init__()
  else:
    raise RuntimeError(
        'Unsupported TensorFlow version detected at runtime. DALIDataset supports versions: 1.13, 1.14'
    )
def __init__(self, input_dataset):
  """See `unbatch()` for more details.

  Args:
    input_dataset: The dataset whose elements are split along their leading
      dimension.

  Raises:
    ValueError: If any component is a scalar, or if the leading dimensions of
      the components cannot be merged with each other.
  """
  legacy_shapes = input_dataset.output_shapes
  component_shapes = nest.flatten(legacy_shapes)

  # Every component needs a leading dimension to unbatch along.
  for shape in component_shapes:
    if shape.ndims == 0:
      raise ValueError("Cannot unbatch an input with scalar components.")

  # Fold all leading dimensions together; a ValueError from `merge_with` is
  # reported as a batch-size mismatch.
  batch_dim = tensor_shape.Dimension(None)
  for shape in component_shapes:
    try:
      batch_dim = batch_dim.merge_with(shape[0])
    except ValueError:
      raise ValueError("Cannot unbatch an input whose components have "
                       "different batch sizes.")

  self._input_dataset = input_dataset
  # The element structure is the input structure minus the leading dimension.
  unbatched_shapes = nest.map_structure(lambda shape: shape[1:],
                                        legacy_shapes)
  self._structure = structure.convert_legacy_structure(
      input_dataset.output_types,
      unbatched_shapes,
      input_dataset.output_classes)

  variant_tensor = ged_ops.experimental_unbatch_dataset(
      self._input_dataset._variant_tensor,  # pylint: disable=protected-access
      **dataset_ops.flat_structure(self))
  super(_UnbatchDataset, self).__init__(input_dataset, variant_tensor)
def __init__(self,
             input_dataset,
             batch_size,
             padded_shapes,
             padding_values,
             drop_remainder):
  """See `Dataset.batch()` for details.

  Args:
    input_dataset: The dataset whose elements are padded and batched.
    batch_size: Passed to the padded-batch op; its constant value becomes the
      static batch dimension when `drop_remainder` is statically true.
    padded_shapes: Nested structure of shapes, one per component.
    padding_values: Nested structure of padding values; flattened and passed
      to the op.
    drop_remainder: Passed to the padded-batch op.
  """
  self._input_dataset = input_dataset
  self._batch_size = batch_size
  self._padded_shapes = padded_shapes
  self._padding_values = padding_values
  self._drop_remainder = drop_remainder

  def _with_batch_dim(padded_shape):
    """Prepends the batch dimension to one padded component shape."""
    # The batch dimension is only taken from `batch_size` when
    # `drop_remainder` has a truthy constant value; otherwise it is None.
    if smart_cond.smart_constant_value(self._drop_remainder):
      leading_dim = tensor_util.constant_value(self._batch_size)
    else:
      leading_dim = None
    batch_shape = tensor_shape.TensorShape([leading_dim])
    return batch_shape.concatenate(
        tensor_util.constant_value_as_shape(padded_shape))

  batched_shapes = nest.map_structure(_with_batch_dim, self._padded_shapes)
  self._structure = structure.convert_legacy_structure(
      ds.get_legacy_output_types(self._input_dataset),
      batched_shapes,
      ds.get_legacy_output_classes(self._input_dataset))

  flat_padded_shapes = [
      ops.convert_to_tensor(shape, dtype=dtypes.int64)
      for shape in nest.flatten(self._padded_shapes)
  ]
  variant_tensor = gen_dataset_ops.padded_batch_dataset_v2(
      input_dataset._variant_tensor,  # pylint: disable=protected-access
      batch_size=self._batch_size,
      padded_shapes=flat_padded_shapes,
      padding_values=nest.flatten(self._padding_values),
      drop_remainder=self._drop_remainder,
      output_shapes=structure.get_flat_tensor_shapes(self._structure))
  super(TntPaddedBatchDataset, self).__init__(input_dataset, variant_tensor)
def from_string_handle(string_handle,
                       output_types,
                       output_shapes=None,
                       output_classes=None):
  """Creates a new, uninitialized `Iterator` based on the given handle.

  This method lets you define a "feedable" iterator, where the concrete
  iterator to use is chosen by feeding a value in a `tf.Session.run` call. In
  that case `string_handle` would be a `tf.compat.v1.placeholder`, fed with
  the value of `tf.data.Iterator.string_handle` in each step.

  For example, with two iterators marking the current position in a training
  dataset and a test dataset, you could choose which one to use in each step
  as follows:

  ```python
  train_iterator = tf.data.Dataset(...).make_one_shot_iterator()
  train_iterator_handle = sess.run(train_iterator.string_handle())

  test_iterator = tf.data.Dataset(...).make_one_shot_iterator()
  test_iterator_handle = sess.run(test_iterator.string_handle())

  handle = tf.compat.v1.placeholder(tf.string, shape=[])
  iterator = tf.data.Iterator.from_string_handle(
      handle, train_iterator.output_types)

  next_element = iterator.get_next()
  loss = f(next_element)

  train_loss = sess.run(loss, feed_dict={handle: train_iterator_handle})
  test_loss = sess.run(loss, feed_dict={handle: test_iterator_handle})
  ```

  Args:
    string_handle: A scalar `tf.Tensor` of type `tf.string` that evaluates to
      a handle produced by the `Iterator.string_handle()` method.
    output_types: A nested structure of `tf.DType` objects corresponding to
      each component of an element of this dataset.
    output_shapes: (Optional.) A nested structure of `tf.TensorShape` objects
      corresponding to each component of an element of this dataset. If
      omitted, each component will have an unconstrained shape.
    output_classes: (Optional.) A nested structure of Python `type` objects
      corresponding to each component of an element of this iterator. If
      omitted, each component is assumed to be of type `tf.Tensor`.

  Returns:
    An `Iterator`.
  """
  output_types = nest.map_structure(dtypes.as_dtype, output_types)
  if output_shapes is not None:
    output_shapes = nest.map_structure_up_to(
        output_types, tensor_shape.as_shape, output_shapes)
  else:
    # No shapes given: every component gets a fully unknown shape.
    output_shapes = nest.map_structure(
        lambda _: tensor_shape.TensorShape(None), output_types)
  if output_classes is None:
    output_classes = nest.map_structure(lambda _: ops.Tensor, output_types)
  nest.assert_same_structure(output_types, output_shapes)

  element_structure = structure.convert_legacy_structure(
      output_types, output_shapes, output_classes)
  handle_tensor = ops.convert_to_tensor(string_handle, dtype=dtypes.string)

  def _resource_from_handle():
    """Creates the iterator resource under the active device scope."""
    return gen_dataset_ops.iterator_from_string_handle_v2(
        handle_tensor,
        output_types=structure.get_flat_tensor_types(element_structure),
        output_shapes=structure.get_flat_tensor_shapes(element_structure))

  # With no device on the stack, the resource is pinned to the CPU.
  if _device_stack_is_empty():
    with ops.device("/cpu:0"):
      iterator_resource = _resource_from_handle()
  else:
    iterator_resource = _resource_from_handle()

  return Iterator(iterator_resource, None, output_types, output_shapes,
                  output_classes)
def __init__(self,
             dataset,
             output_types,
             output_shapes=None,
             output_classes=None,
             allow_unsafe_cast=False):
  """Creates a new dataset with the given output types and shapes.

  The given `dataset` must have a structure that is convertible:
  * `dataset.output_types` must be the same as `output_types` modulo nesting.
  * Each shape in `dataset.output_shapes` must be compatible with each shape
    in `output_shapes` (if given).

  Note: This helper permits "unsafe casts" for shapes, equivalent to using
  `tf.Tensor.set_shape()` where domain-specific knowledge is available.

  Args:
    dataset: A `Dataset` object.
    output_types: A nested structure of `tf.DType` objects.
    output_shapes: (Optional.) A nested structure of `tf.TensorShape` objects.
      If omitted, the shapes will be inherited from `dataset`.
    output_classes: (Optional.) A nested structure of class types. If omitted,
      the class types will be inherited from `dataset`.
    allow_unsafe_cast: (Optional.) If `True`, the type and shape validation is
      skipped, so the caller may switch the reported output types and shapes
      of the restructured dataset, e.g. to switch a sparse tensor represented
      as `tf.variant` to its user-visible type and shape.

  Raises:
    ValueError: If either `output_types` or `output_shapes` is not compatible
      with the structure of `dataset`.
  """
  self._input_dataset = dataset
  validate = not allow_unsafe_cast

  # Types: the flattened lists must match exactly.
  source_types = dataset_ops.get_legacy_output_types(dataset)
  if validate:
    output_types = nest.map_structure(dtypes.as_dtype, output_types)
    if nest.flatten(source_types) != nest.flatten(output_types):
      raise ValueError(
          "Dataset with output types %r cannot be restructured to have "
          "output types %r" % (source_types, output_types))

  # Shapes: inherited from `dataset` when omitted, otherwise checked
  # component by component and normalized.
  source_shapes = dataset_ops.get_legacy_output_shapes(dataset)
  if output_shapes is None:
    output_shapes = nest.pack_sequence_as(output_types,
                                          nest.flatten(source_shapes))
  else:
    if validate:
      nest.assert_same_structure(output_types, output_shapes)
      shape_pairs = zip(nest.flatten(source_shapes),
                        nest.flatten_up_to(output_types, output_shapes))
      for source_shape, requested_shape in shape_pairs:
        if not source_shape.is_compatible_with(requested_shape):
          raise ValueError(
              "Dataset with output shapes %r cannot be restructured to have "
              "incompatible output shapes %r" % (source_shapes,
                                                 output_shapes))
    output_shapes = nest.map_structure_up_to(
        output_types, tensor_shape.as_shape, output_shapes)

  # Classes: inherited from `dataset` when omitted.
  source_classes = dataset_ops.get_legacy_output_classes(dataset)
  if output_classes is None:
    output_classes = nest.pack_sequence_as(output_types,
                                           nest.flatten(source_classes))

  self._structure = structure.convert_legacy_structure(
      output_types, output_shapes, output_classes)
  # The input's variant tensor is reused as-is; only the reported structure
  # differs.
  variant_tensor = self._input_dataset._variant_tensor  # pylint: disable=protected-access
  super(_RestructuredDataset, self).__init__(dataset, variant_tensor)
def __init__(self, input_dataset, initial_state, scan_func):
  """See `scan()` for details.

  Traces `scan_func` repeatedly, relaxing the state shapes after each trace,
  until the state shapes stop changing, then builds the scan dataset op.

  Args:
    input_dataset: The dataset whose elements are fed to `scan_func`.
    initial_state: The initial state value; normalized with
      `structure.normalize_tensors`.
    scan_func: Callable that must return a pair `(new_state, output)`.

  Raises:
    TypeError: If `scan_func` does not return a pair, or if the classes or
      types of the new state do not match those of the initial state.
  """
  self._input_dataset = input_dataset

  with ops.name_scope("initial_state"):
    self._initial_state = structure.normalize_tensors(initial_state)

  # Compute initial values for the state classes, shapes and types based on
  # the initial state. The shapes may be refined by running `tf_scan_func` one
  # or more times below.
  self._state_structure = structure.Structure.from_value(self._initial_state)

  # Iteratively rerun the scan function until reaching a fixed point on
  # `self._state_shapes`.
  need_to_rerun = True
  while need_to_rerun:

    # The function is not added to the graph yet: only the trace that
    # survives the fixed-point loop is added, after the loop.
    wrapped_func = dataset_ops.StructuredFunctionWrapper(
        scan_func,
        self._transformation_name(),
        input_structure=structure.NestedStructure(
            (self._state_structure, input_dataset._element_structure)),  # pylint: disable=protected-access
        add_to_graph=False)
    # NOTE(review): `collections.Sequence` is a deprecated alias that newer
    # Python 3 releases no longer provide -- confirm which interpreter
    # versions this file must support.
    if not (
        isinstance(wrapped_func.output_types, collections.Sequence) and
        len(wrapped_func.output_types) == 2):
      raise TypeError("The scan function must return a pair comprising the "
                      "new state and the output value.")

    # This first unpacking records the output classes on the instance; the
    # state half is unpacked again just below.
    new_state_classes, self._output_classes = wrapped_func.output_classes

    # Extract and validate class information from the returned values.
    new_state_classes, output_classes = wrapped_func.output_classes
    old_state_classes = self._state_structure._to_legacy_output_classes()  # pylint: disable=protected-access
    for new_state_class, old_state_class in zip(
        nest.flatten(new_state_classes),
        nest.flatten(old_state_classes)):
      if not issubclass(new_state_class, old_state_class):
        raise TypeError(
            "The element classes for the new state must match the initial "
            "state. Expected %s; got %s." %
            (old_state_classes, new_state_classes))

    # Extract and validate type information from the returned values.
    new_state_types, output_types = wrapped_func.output_types
    old_state_types = self._state_structure._to_legacy_output_types()  # pylint: disable=protected-access
    for new_state_type, old_state_type in zip(
        nest.flatten(new_state_types), nest.flatten(old_state_types)):
      if new_state_type != old_state_type:
        raise TypeError(
            "The element types for the new state must match the initial "
            "state. Expected %s; got %s." %
            (old_state_types, new_state_types))

    # Extract shape information from the returned values.
    new_state_shapes, output_shapes = wrapped_func.output_shapes
    old_state_shapes = self._state_structure._to_legacy_output_shapes()  # pylint: disable=protected-access
    # The element structure is overwritten on every pass, so the value left
    # after the loop comes from the final trace.
    self._structure = structure.convert_legacy_structure(
        output_types, output_shapes, output_classes)

    # For each state component, the most specific shape compatible with both
    # the shape fed in and the shape that came out.
    flat_state_shapes = nest.flatten(old_state_shapes)
    flat_new_state_shapes = nest.flatten(new_state_shapes)
    weakened_state_shapes = [
        original.most_specific_compatible_shape(new)
        for original, new in zip(flat_state_shapes, flat_new_state_shapes)
    ]

    # Another pass is needed if any component with known rank had its shape
    # changed by the weakening step.
    need_to_rerun = False
    for original_shape, weakened_shape in zip(flat_state_shapes,
                                              weakened_state_shapes):
      if original_shape.ndims is not None and (
          weakened_shape.ndims is None or
          original_shape.as_list() != weakened_shape.as_list()):
        need_to_rerun = True
        break

    if need_to_rerun:
      # TODO(b/110122868): Support a "most specific compatible structure"
      # method for combining structures, to avoid using legacy structures
      # in this method.
      self._state_structure = structure.convert_legacy_structure(
          old_state_types,
          nest.pack_sequence_as(old_state_shapes, weakened_state_shapes),
          old_state_classes)

  self._scan_func = wrapped_func
  self._scan_func.function.add_to_graph(ops.get_default_graph())
  # pylint: disable=protected-access
  variant_tensor = gen_experimental_dataset_ops.experimental_scan_dataset(
      self._input_dataset._variant_tensor,
      self._state_structure._to_tensor_list(self._initial_state),
      self._scan_func.function.captured_inputs,
      f=self._scan_func.function,
      preserve_cardinality=True,
      **dataset_ops.flat_structure(self))
  super(_ScanDataset, self).__init__(input_dataset, variant_tensor)
def from_structure(output_types,
                   output_shapes=None,
                   shared_name=None,
                   output_classes=None):
  """Creates a new, uninitialized `Iterator` with the given structure.

  This iterator-constructing method can be used to create an iterator that
  is reusable with many different datasets.

  The returned iterator is not bound to a particular dataset, and it has
  no `initializer`. To initialize the iterator, run the operation returned by
  `Iterator.make_initializer(dataset)`.

  The following is an example

  ```python
  iterator = Iterator.from_structure(tf.int64, tf.TensorShape([]))

  dataset_range = Dataset.range(10)
  range_initializer = iterator.make_initializer(dataset_range)

  dataset_evens = dataset_range.filter(lambda x: x % 2 == 0)
  evens_initializer = iterator.make_initializer(dataset_evens)

  # Define a model based on the iterator; in this example, the model_fn
  # is expected to take scalar tf.int64 Tensors as input (see
  # the definition of 'iterator' above).
  prediction, loss = model_fn(iterator.get_next())

  # Train for `num_epochs`, where for each epoch, we first iterate over
  # dataset_range, and then iterate over dataset_evens.
  for _ in range(num_epochs):
    # Initialize the iterator to `dataset_range`
    sess.run(range_initializer)
    while True:
      try:
        pred, loss_val = sess.run([prediction, loss])
      except tf.errors.OutOfRangeError:
        break

    # Initialize the iterator to `dataset_evens`
    sess.run(evens_initializer)
    while True:
      try:
        pred, loss_val = sess.run([prediction, loss])
      except tf.errors.OutOfRangeError:
        break
  ```

  Args:
    output_types: A nested structure of `tf.DType` objects corresponding to
      each component of an element of this dataset.
    output_shapes: (Optional.) A nested structure of `tf.TensorShape` objects
      corresponding to each component of an element of this dataset. If
      omitted, each component will have an unconstrained shape.
    shared_name: (Optional.) If non-empty, this iterator will be shared under
      the given name across multiple sessions that share the same devices
      (e.g. when using a remote server).
    output_classes: (Optional.) A nested structure of Python `type` objects
      corresponding to each component of an element of this iterator. If
      omitted, each component is assumed to be of type `tf.Tensor`.

  Returns:
    An `Iterator`.

  Raises:
    TypeError: If the structures of `output_shapes` and `output_types` are
      not the same.
  """
  output_types = nest.map_structure(dtypes.as_dtype, output_types)
  if output_shapes is not None:
    output_shapes = nest.map_structure_up_to(
        output_types, tensor_shape.as_shape, output_shapes)
  else:
    # No shapes given: every component gets a fully unknown shape.
    output_shapes = nest.map_structure(
        lambda _: tensor_shape.TensorShape(None), output_types)
  if output_classes is None:
    output_classes = nest.map_structure(lambda _: ops.Tensor, output_types)
  nest.assert_same_structure(output_types, output_shapes)

  element_structure = structure.convert_legacy_structure(
      output_types, output_shapes, output_classes)
  if shared_name is None:
    shared_name = ""

  def _new_resource():
    """Creates the iterator resource under the active device scope."""
    return gen_dataset_ops.iterator_v2(
        container="",
        shared_name=shared_name,
        output_types=structure.get_flat_tensor_types(element_structure),
        output_shapes=structure.get_flat_tensor_shapes(element_structure))

  # With no device on the stack, the resource is pinned to the CPU.
  if _device_stack_is_empty():
    with ops.device("/cpu:0"):
      iterator_resource = _new_resource()
  else:
    iterator_resource = _new_resource()

  return Iterator(iterator_resource, None, output_types, output_shapes,
                  output_classes)
def __init__(self,
             input_dataset,
             functions,
             ratio_numerator=1,
             ratio_denominator=1,
             num_elements_per_branch=None):
  """Chooses the fastest of some dataset functions.

  Given dataset functions that take input_dataset as input and output
  another dataset, produces elements as quickly as the fastest of these
  output datasets. Note that datasets in the dataset functions are assumed
  to be stateless, and the iterators created by the functions' output
  datasets will, given the same input elements, all produce the same output
  elements. Datasets in the functions are also expected to iterate over the
  input dataset at most once. The violation of these conditions may lead to
  undefined behavior.

  For example:

  ```python
  dataset = tf.data.Dataset.range(100)
  dataset = _ChooseFastestBranchDataset(
      dataset,
      [
          lambda ds: ds.map(lambda x: tf.reshape(x, [1])).batch(10),
          lambda ds: ds.batch(10).map(lambda x: tf.reshape(x, [10, 1]))
      ],
      ratio_numerator=10,
      num_elements_per_branch=10
  )
  ```

  The resulting dataset will produce elements equivalent to
  `tf.data.Dataset.range(100).map(lambda x: tf.reshape(x, [1])).batch(10)`, or
  `tf.data.Dataset.range(100).batch(10).map(lambda x: tf.reshape(x, [10, 1]))`

  Note that the first `num_elements_per_branch` iterations may be slower due
  to the overhead of dynamically picking the fastest dataset. Namely, for
  these iterations, the dataset will produce elements from any of branches to
  determine which input is the fastest. For all subsequent iterations, that
  input will be used.

  Args:
    input_dataset: A `Dataset` that can be used as input to `functions`.
    functions: A list of callables, each of which takes a `Dataset` as input
      and returns a `Dataset`.
    ratio_numerator: The numerator in the ratio of input elements consumed to
      output elements produced for each function. This should be the same
      for all functions. For example, if the function is
      `lambda ds: ds.batch(10)`, the ratio is 10:1, i.e. the input dataset
      must produce 10 elements for every element of the output dataset. In
      this case, ratio_numerator should be 10.
    ratio_denominator: The denominator in the ratio of input elements
      consumed to output elements produced for each function. This should be
      the same for all functions. For example, if the function is
      `lambda ds: ds.batch(10)`, the ratio is 10:1, i.e. the input dataset
      must produce 10 elements for every element of the output dataset. In
      this case, ratio_denominator should be 1.
    num_elements_per_branch: The number of elements to get from each branch
      before deciding which dataset is fastest. In the first
      len(functions) * num_elements_per_branch iterations, the dataset will
      call from one of the branches, and update its knowledge of which input
      is the fastest. Note that (num_elements_per_branch * ratio) is expected
      to be an integer. Defaults to `10 * ratio_denominator` when omitted.

  Raises:
    ValueError: If `ratio_numerator` or `ratio_denominator` is not positive.
  """
  # Each branch function receives a dataset with the input's structure.
  input_element_structure = structure_lib.convert_legacy_structure(
      input_dataset.output_types,
      input_dataset.output_shapes,
      input_dataset.output_classes)
  branch_input_structure = structure_lib.NestedStructure(
      dataset_ops.DatasetStructure(input_element_structure))

  self._funcs = []
  for branch_fn in functions:
    self._funcs.append(
        dataset_ops.StructuredFunctionWrapper(
            branch_fn,
            "ChooseFastestV2",
            input_structure=branch_input_structure))
  # The element structure is taken from the first branch's output.
  self._structure = self._funcs[0].output_structure._element_structure  # pylint: disable=protected-access

  # Captured inputs of all branches are concatenated; the per-branch lengths
  # are passed alongside them.
  self._captured_arguments = [
      captured for wrapped in self._funcs
      for captured in wrapped.function.captured_inputs
  ]
  self._capture_lengths = [
      len(wrapped.function.captured_inputs) for wrapped in self._funcs
  ]

  if ratio_numerator <= 0 or ratio_denominator <= 0:
    raise ValueError("ratio must be positive.")

  if num_elements_per_branch is None:
    # Pick a sensible default based on `ratio_denominator`
    num_elements_per_branch = 10 * ratio_denominator

  branch_functions = [wrapped.function for wrapped in self._funcs]
  variant_tensor = (
      gen_experimental_dataset_ops.choose_fastest_branch_dataset(
          input_dataset._variant_tensor,  # pylint: disable=protected-access
          ratio_numerator=ratio_numerator,
          ratio_denominator=ratio_denominator,
          other_arguments=self._captured_arguments,
          num_elements_per_branch=num_elements_per_branch,
          branches=branch_functions,
          other_arguments_lengths=self._capture_lengths,
          **dataset_ops.flat_structure(self)))
  super(_ChooseFastestBranchDataset, self).__init__(input_dataset,
                                                    variant_tensor)
def _make_reduce_func(self, reduce_func, input_dataset):
  """Make wrapping defun for reduce_func.

  Traces `reduce_func` repeatedly, relaxing the state shapes after each
  trace, until the state shapes stop changing. The final trace is stored as
  `self._reduce_func` and added to the default graph.

  Args:
    reduce_func: Callable traced with `(state, input_element)` structures.
    input_dataset: Dataset whose `element_spec` describes the input elements.

  Raises:
    TypeError: If the classes or types returned by `reduce_func` do not match
      those of the state produced by `self._init_func`.
  """
  # Iteratively rerun the reduce function until reaching a fixed point on
  # `self._state_structure`.
  self._state_structure = self._init_func.output_structure
  state_types = self._init_func.output_types
  state_shapes = self._init_func.output_shapes
  state_classes = self._init_func.output_classes
  need_to_rerun = True
  while need_to_rerun:

    # Not added to the graph yet: only the final trace is added, after the
    # loop.
    wrapped_func = structured_function.StructuredFunctionWrapper(
        reduce_func,
        self._transformation_name(),
        input_structure=(self._state_structure, input_dataset.element_spec),
        add_to_graph=False)

    # Extract and validate class information from the returned values.
    for new_state_class, state_class in zip(
        nest.flatten(wrapped_func.output_classes),
        nest.flatten(state_classes)):
      if not issubclass(new_state_class, state_class):
        # Report the classes that were actually compared (the local
        # `state_classes`), so raising this error cannot itself fail on a
        # missing instance attribute.
        raise TypeError(
            f"Invalid `reducer`. The output class of the "
            f"`reducer.reduce_func` {wrapped_func.output_classes}, "
            f"does not match the class of the reduce state "
            f"{state_classes}.")

    # Extract and validate type information from the returned values.
    for new_state_type, state_type in zip(
        nest.flatten(wrapped_func.output_types), nest.flatten(state_types)):
      if new_state_type != state_type:
        raise TypeError(
            f"Invalid `reducer`. The element types for the new state "
            f"{wrapped_func.output_types} do not match the element types "
            f"of the old state {self._init_func.output_types}.")

    # Extract shape information from the returned values.
    flat_state_shapes = nest.flatten(state_shapes)
    flat_new_state_shapes = nest.flatten(wrapped_func.output_shapes)
    weakened_state_shapes = [
        original.most_specific_compatible_shape(new)
        for original, new in zip(flat_state_shapes, flat_new_state_shapes)
    ]

    # Another pass is needed if any component with known rank had its shape
    # changed by the weakening step.
    need_to_rerun = False
    for original_shape, weakened_shape in zip(flat_state_shapes,
                                              weakened_state_shapes):
      if original_shape.ndims is not None and (
          weakened_shape.ndims is None or
          original_shape.as_list() != weakened_shape.as_list()):
        need_to_rerun = True
        break

    if need_to_rerun:
      state_shapes = nest.pack_sequence_as(
          self._init_func.output_shapes, weakened_state_shapes)
      self._state_structure = structure.convert_legacy_structure(
          state_types, state_shapes, state_classes)

  self._reduce_func = wrapped_func
  self._reduce_func.function.add_to_graph(ops.get_default_graph())
def __init__(self, input_dataset, initial_state, scan_func,
             use_default_device=None):
  """See `scan()` for details.

  Traces `scan_func` repeatedly, relaxing the state shapes after each trace,
  until the state shapes stop changing, then builds the scan dataset op.

  Args:
    input_dataset: The dataset whose elements are fed to `scan_func`.
    initial_state: The initial state value; normalized with
      `structure.normalize_element`.
    scan_func: Callable that must return a pair `(new_state, output)`.
    use_default_device: (Optional.) Forwarded to the scan op when it is not
      `None` or when the forward-compatibility date has passed.

  Raises:
    TypeError: If `scan_func` does not return a pair, or if the classes or
      types of the new state do not match those of the initial state.
  """
  self._input_dataset = input_dataset
  self._initial_state = structure.normalize_element(initial_state)

  # Compute initial values for the state classes, shapes and types based on
  # the initial state. The shapes may be refined by running `tf_scan_func` one
  # or more times below.
  self._state_structure = structure.type_spec_from_value(self._initial_state)

  # Iteratively rerun the scan function until reaching a fixed point on
  # `self._state_shapes`.
  need_to_rerun = True
  while need_to_rerun:

    # The function is not added to the graph yet: only the trace that
    # survives the fixed-point loop is added, after the loop.
    wrapped_func = dataset_ops.StructuredFunctionWrapper(
        scan_func,
        self._transformation_name(),
        input_structure=(self._state_structure, input_dataset.element_spec),
        add_to_graph=False)
    if not (isinstance(wrapped_func.output_types, collections_abc.Sequence) and
            len(wrapped_func.output_types) == 2):
      raise TypeError("The scan function must return a pair comprising the "
                      "new state and the output value.")

    # This first unpacking records the output classes on the instance; the
    # state half is unpacked again just below.
    new_state_classes, self._output_classes = wrapped_func.output_classes

    # Extract and validate class information from the returned values.
    new_state_classes, output_classes = wrapped_func.output_classes
    old_state_classes = nest.map_structure(
        lambda component_spec: component_spec._to_legacy_output_classes(),  # pylint: disable=protected-access
        self._state_structure)
    for new_state_class, old_state_class in zip(
        nest.flatten(new_state_classes),
        nest.flatten(old_state_classes)):
      if not issubclass(new_state_class, old_state_class):
        raise TypeError(
            "The element classes for the new state must match the initial "
            "state. Expected %s; got %s." %
            (old_state_classes, new_state_classes))

    # Extract and validate type information from the returned values.
    new_state_types, output_types = wrapped_func.output_types
    old_state_types = nest.map_structure(
        lambda component_spec: component_spec._to_legacy_output_types(),  # pylint: disable=protected-access
        self._state_structure)
    for new_state_type, old_state_type in zip(
        nest.flatten(new_state_types), nest.flatten(old_state_types)):
      if new_state_type != old_state_type:
        raise TypeError(
            "The element types for the new state must match the initial "
            "state. Expected %s; got %s." %
            (old_state_types, new_state_types))

    # Extract shape information from the returned values.
    new_state_shapes, output_shapes = wrapped_func.output_shapes
    old_state_shapes = nest.map_structure(
        lambda component_spec: component_spec._to_legacy_output_shapes(),  # pylint: disable=protected-access
        self._state_structure)
    # The element spec is overwritten on every pass, so the value left after
    # the loop comes from the final trace.
    self._element_spec = structure.convert_legacy_structure(
        output_types, output_shapes, output_classes)

    # For each state component, the most specific shape compatible with both
    # the shape fed in and the shape that came out.
    flat_state_shapes = nest.flatten(old_state_shapes)
    flat_new_state_shapes = nest.flatten(new_state_shapes)
    weakened_state_shapes = [
        original.most_specific_compatible_shape(new)
        for original, new in zip(flat_state_shapes, flat_new_state_shapes)
    ]

    # Another pass is needed if any component with known rank had its shape
    # changed by the weakening step.
    need_to_rerun = False
    for original_shape, weakened_shape in zip(flat_state_shapes,
                                              weakened_state_shapes):
      if original_shape.ndims is not None and (
          weakened_shape.ndims is None or
          original_shape.as_list() != weakened_shape.as_list()):
        need_to_rerun = True
        break

    if need_to_rerun:
      # TODO(b/110122868): Support a "most specific compatible structure"
      # method for combining structures, to avoid using legacy structures
      # in this method.
      self._state_structure = structure.convert_legacy_structure(
          old_state_types,
          nest.pack_sequence_as(old_state_shapes, weakened_state_shapes),
          old_state_classes)

  self._scan_func = wrapped_func
  self._scan_func.function.add_to_graph(ops.get_default_graph())
  # pylint: disable=protected-access
  # `use_default_device` is only passed to the op when the caller set it or
  # the forward-compatibility window has passed; otherwise the op is built
  # without that attribute.
  if compat.forward_compatible(2019, 10, 15) or use_default_device is not None:
    variant_tensor = gen_experimental_dataset_ops.scan_dataset(
        self._input_dataset._variant_tensor,
        structure.to_tensor_list(self._state_structure, self._initial_state),
        self._scan_func.function.captured_inputs,
        f=self._scan_func.function,
        preserve_cardinality=True,
        use_default_device=use_default_device,
        **self._flat_structure)
  else:
    variant_tensor = gen_experimental_dataset_ops.scan_dataset(
        self._input_dataset._variant_tensor,
        structure.to_tensor_list(self._state_structure, self._initial_state),
        self._scan_func.function.captured_inputs,
        f=self._scan_func.function,
        preserve_cardinality=True,
        **self._flat_structure)
  super(_ScanDataset, self).__init__(input_dataset, variant_tensor)
def _make_reduce_func(self, reduce_func, input_dataset):
    """Wraps `reduce_func`, re-tracing it until the state shapes stabilise.

    `reduce_func` is traced with a `(state, input element)` input structure.
    After each trace the shapes it returned are merged with the current state
    shapes via `most_specific_compatible_shape`. If that merge loosens any
    state shape, `self._state_structure` is rebuilt from the loosened shapes
    and the function is traced again, until a fixed point is reached. The
    final wrapped function is stored in `self._reduce_func` and only then
    added to the default graph.

    Args:
      reduce_func: The callable handed to
        `dataset_ops.StructuredFunctionWrapper`.
      input_dataset: The dataset whose `_element_structure` describes the
        second component of the input to `reduce_func`.

    Raises:
      TypeError: If the classes or the types returned by `reduce_func` do not
        match the ones produced by `self._init_func`.
    """
    self._state_structure = self._init_func.output_structure
    state_types = self._init_func.output_types
    state_shapes = self._init_func.output_shapes
    state_classes = self._init_func.output_classes

    need_to_rerun = True
    while need_to_rerun:
        # Trace without touching the graph; only the final trace is added.
        wrapped_func = dataset_ops.StructuredFunctionWrapper(
            reduce_func,
            self._transformation_name(),
            input_structure=structure.NestedStructure(
                (self._state_structure, input_dataset._element_structure)),  # pylint: disable=protected-access
            add_to_graph=False)

        # Extract and validate class information from the returned values.
        for new_state_class, state_class in zip(
                nest.flatten(wrapped_func.output_classes),
                nest.flatten(state_classes)):
            if not issubclass(new_state_class, state_class):
                # Report the classes that were actually compared. The
                # previous message read `self._state_classes`, which this
                # method never assigns.
                raise TypeError(
                    "The element classes for the new state must match the initial "
                    "state. Expected %s; got %s." %
                    (state_classes, wrapped_func.output_classes))

        # Extract and validate type information from the returned values.
        for new_state_type, state_type in zip(
                nest.flatten(wrapped_func.output_types),
                nest.flatten(state_types)):
            if new_state_type != state_type:
                raise TypeError(
                    "The element types for the new state must match the initial "
                    "state. Expected %s; got %s." %
                    (self._init_func.output_types, wrapped_func.output_types))

        # Extract shape information from the returned values.
        flat_state_shapes = nest.flatten(state_shapes)
        flat_new_state_shapes = nest.flatten(wrapped_func.output_shapes)
        weakened_state_shapes = [
            original.most_specific_compatible_shape(new)
            for original, new in zip(flat_state_shapes, flat_new_state_shapes)
        ]

        # A rerun is needed as soon as one known-rank state shape had to be
        # loosened to stay compatible with what the function returned.
        need_to_rerun = any(
            original_shape.ndims is not None and (
                weakened_shape.ndims is None or
                original_shape.as_list() != weakened_shape.as_list())
            for original_shape, weakened_shape in zip(
                flat_state_shapes, weakened_state_shapes))

        if need_to_rerun:
            state_shapes = nest.pack_sequence_as(
                self._init_func.output_shapes, weakened_state_shapes)
            self._state_structure = structure.convert_legacy_structure(
                state_types, state_shapes, state_classes)

    self._reduce_func = wrapped_func
    self._reduce_func.function.add_to_graph(ops.get_default_graph())
def testConvertLegacyStructure(self, output_types, output_shapes,
                               output_classes, expected_structure):
    """Checks the converted legacy triple is mutually compatible with the expectation.

    Args:
      output_types: Legacy types passed to `convert_legacy_structure`.
      output_shapes: Legacy shapes passed to `convert_legacy_structure`.
      output_classes: Legacy classes passed to `convert_legacy_structure`.
      expected_structure: The structure the conversion result is compared to.
    """
    legacy_triple = (output_types, output_shapes, output_classes)
    converted = structure.convert_legacy_structure(*legacy_triple)
    # Compatibility must hold in both directions.
    self.assertTrue(expected_structure.is_compatible_with(converted))
    self.assertTrue(converted.is_compatible_with(expected_structure))
def __init__(self,
             input_dataset,
             functions,
             ratio_numerator=1,
             ratio_denominator=1,
             num_elements_per_branch=None):
    """Chooses the fastest of some dataset functions.

    Given dataset functions that take input_dataset as input and output
    another dataset, produces elements as quickly as the fastest of these
    output datasets. Note that datasets in the dataset functions are assumed
    to be stateless, and the iterators created by the functions' output
    datasets will, given the same input elements, all produce the same output
    elements. Datasets in the functions are also expected to iterate over the
    input dataset at most once. The violation of these conditions may lead to
    undefined behavior.

    For example:

    ```python
    dataset = tf.data.Dataset.range(100)
    dataset = _ChooseFastestBranchDataset(
        dataset,
        [
            lambda ds: ds.map(lambda x: tf.reshape(x, [1])).batch(10),
            lambda ds: ds.batch(10).map(lambda x: tf.reshape(x, [10, 1]))
        ],
        ratio_numerator=10,
        num_elements_per_branch=10
    )
    ```

    The resulting dataset will produce elements equivalent to
    `tf.data.Dataset.range(100).map(lambda x: tf.reshape(x, [1])).batch(10)`,
    or
    `tf.data.Dataset.range(100).batch(10).map(lambda x: tf.reshape(x, [10, 1]))`

    Note that the first `num_elements_per_branch` iterations may be slower due
    to the overhead of dynamically picking the fastest dataset. Namely, for
    these iterations, the dataset will produce elements from any of the
    branches to determine which input is the fastest. For all subsequent
    iterations, that input will be used.

    The resulting dataset has the same elements as the datasets returned by
    `functions`.

    Args:
      input_dataset: A `Dataset` that can be used as input to `functions`.
      functions: A list of callables, each of which takes a `Dataset` as input
        and returns a `Dataset`.
      ratio_numerator: The numerator in the ratio of input elements consumed
        to output elements produced for each function. This should be the
        same for all functions. For example, if the function is
        `lambda ds: ds.batch(10)`, the ratio is 10:1, i.e. the input dataset
        must produce 10 elements for every element of the output dataset. In
        this case, ratio_numerator should be 10.
      ratio_denominator: The denominator in the ratio of input elements
        consumed to output elements produced for each function. This should
        be the same for all functions. In the `lambda ds: ds.batch(10)`
        example above, ratio_denominator should be 1.
      num_elements_per_branch: The number of elements to get from each branch
        before deciding which dataset is fastest. In the first
        len(functions) * num_elements_per_branch iterations, the dataset will
        call from one of the branches, and update its knowledge of which
        input is the fastest. Note that (num_elements_per_branch * ratio) is
        expected to be an integer. Defaults to `10 * ratio_denominator` when
        `None`.

    Raises:
      ValueError: If `ratio_numerator` or `ratio_denominator` is not positive.
    """
    # Validate the cheap scalar arguments first so that invalid input is
    # rejected before any branch function is traced.
    if ratio_numerator <= 0 or ratio_denominator <= 0:
        raise ValueError("ratio must be positive.")

    if num_elements_per_branch is None:
        # Pick a sensible default based on `ratio_denominator`
        num_elements_per_branch = 10 * ratio_denominator

    # Every branch function receives a dataset with the structure of
    # `input_dataset` as its single argument.
    nested_structure = structure_lib.NestedStructure(
        dataset_ops.DatasetStructure(
            structure_lib.convert_legacy_structure(
                input_dataset.output_types, input_dataset.output_shapes,
                input_dataset.output_classes)))
    self._funcs = [
        dataset_ops.StructuredFunctionWrapper(
            f, "ChooseFastestV2", input_structure=nested_structure)
        for f in functions
    ]
    # The element structure is taken from the first branch.
    self._structure = self._funcs[0].output_structure._element_structure  # pylint: disable=protected-access

    # The op takes all captured inputs as one flat list plus the number of
    # entries that belong to each branch.
    self._captured_arguments = []
    for f in self._funcs:
        self._captured_arguments.extend(f.function.captured_inputs)
    self._capture_lengths = [
        len(f.function.captured_inputs) for f in self._funcs
    ]

    variant_tensor = (
        gen_experimental_dataset_ops.choose_fastest_branch_dataset(
            input_dataset._variant_tensor,  # pylint: disable=protected-access
            ratio_numerator=ratio_numerator,
            ratio_denominator=ratio_denominator,
            other_arguments=self._captured_arguments,
            num_elements_per_branch=num_elements_per_branch,
            branches=[f.function for f in self._funcs],
            other_arguments_lengths=self._capture_lengths,
            **dataset_ops.flat_structure(self)))
    super(_ChooseFastestBranchDataset, self).__init__(input_dataset,
                                                      variant_tensor)
def from_structure(output_types,
                   output_shapes=None,
                   shared_name=None,
                   output_classes=None):
    """Builds an uninitialized `Iterator` described only by its structure.

    The iterator returned here is not tied to any dataset and carries no
    `initializer`, which makes it reusable across several datasets with the
    same structure. Bind it to a dataset by running the operation returned by
    `Iterator.make_initializer(dataset)`.

    Example:

    ```python
    iterator = Iterator.from_structure(tf.int64, tf.TensorShape([]))

    dataset_range = Dataset.range(10)
    range_initializer = iterator.make_initializer(dataset_range)

    dataset_evens = dataset_range.filter(lambda x: x % 2 == 0)
    evens_initializer = iterator.make_initializer(dataset_evens)

    # The model consumes scalar tf.int64 tensors, matching the structure the
    # iterator was created with.
    prediction, loss = model_fn(iterator.get_next())

    for _ in range(num_epochs):
      # First pass of the epoch: iterate over `dataset_range`.
      sess.run(range_initializer)
      while True:
        try:
          pred, loss_val = sess.run([prediction, loss])
        except tf.errors.OutOfRangeError:
          break

      # Second pass of the epoch: iterate over `dataset_evens`.
      sess.run(evens_initializer)
      while True:
        try:
          pred, loss_val = sess.run([prediction, loss])
        except tf.errors.OutOfRangeError:
          break
    ```

    Args:
      output_types: A nested structure of `tf.DType` objects corresponding to
        each component of an element of this dataset.
      output_shapes: (Optional.) A nested structure of `tf.TensorShape`
        objects corresponding to each component of an element of this
        dataset. If omitted, each component will have an unconstrained shape.
      shared_name: (Optional.) If non-empty, this iterator will be shared
        under the given name across multiple sessions that share the same
        devices (e.g. when using a remote server).
      output_classes: (Optional.) A nested structure of Python `type` objects
        corresponding to each component of an element of this iterator. If
        omitted, each component is assumed to be of type `tf.Tensor`.

    Returns:
      An `Iterator`.

    Raises:
      TypeError: If the structures of `output_shapes` and `output_types` are
        not the same.
    """
    output_types = nest.map_structure(dtypes.as_dtype, output_types)
    if output_shapes is not None:
        output_shapes = nest.map_structure_up_to(
            output_types, tensor_shape.as_shape, output_shapes)
    else:
        # No shapes given: every component gets a fully unknown shape.
        output_shapes = nest.map_structure(
            lambda _: tensor_shape.TensorShape(None), output_types)
    if output_classes is None:
        output_classes = nest.map_structure(lambda _: ops.Tensor, output_types)
    nest.assert_same_structure(output_types, output_shapes)
    output_structure = structure_lib.convert_legacy_structure(
        output_types, output_shapes, output_classes)
    shared_name = "" if shared_name is None else shared_name

    def _create_resource(iterator_op):
        # All op variants take the same keyword arguments.
        # pylint: disable=protected-access
        return iterator_op(
            container="",
            shared_name=shared_name,
            output_types=output_structure._flat_types,
            output_shapes=output_structure._flat_shapes)
        # pylint: enable=protected-access

    if not compat.forward_compatible(2018, 8, 3):
        iterator_resource = _create_resource(gen_dataset_ops.iterator)
    elif _device_stack_is_empty():
        # With no device scope active, create the resource on the CPU.
        with ops.device("/cpu:0"):
            iterator_resource = _create_resource(gen_dataset_ops.iterator_v2)
    else:
        iterator_resource = _create_resource(gen_dataset_ops.iterator_v2)

    return Iterator(iterator_resource, None, output_types, output_shapes,
                    output_classes)
def from_string_handle(string_handle,
                       output_types,
                       output_shapes=None,
                       output_classes=None):
    """Builds an uninitialized `Iterator` backed by a string handle.

    This is the entry point for "feedable" iterators: `string_handle` can be a
    `tf.placeholder` that is fed, in each `tf.Session.run` call, with the
    value of `tf.data.Iterator.string_handle` of whichever concrete iterator
    should be used for that step. For example, to switch between an iterator
    over a training dataset and one over a test dataset:

    ```python
    train_iterator = tf.data.Dataset(...).make_one_shot_iterator()
    train_iterator_handle = sess.run(train_iterator.string_handle())

    test_iterator = tf.data.Dataset(...).make_one_shot_iterator()
    test_iterator_handle = sess.run(test_iterator.string_handle())

    handle = tf.placeholder(tf.string, shape=[])
    iterator = tf.data.Iterator.from_string_handle(
        handle, train_iterator.output_types)

    next_element = iterator.get_next()
    loss = f(next_element)

    train_loss = sess.run(loss, feed_dict={handle: train_iterator_handle})
    test_loss = sess.run(loss, feed_dict={handle: test_iterator_handle})
    ```

    Args:
      string_handle: A scalar `tf.Tensor` of type `tf.string` that evaluates
        to a handle produced by the `Iterator.string_handle()` method.
      output_types: A nested structure of `tf.DType` objects corresponding to
        each component of an element of this dataset.
      output_shapes: (Optional.) A nested structure of `tf.TensorShape`
        objects corresponding to each component of an element of this
        dataset. If omitted, each component will have an unconstrained shape.
      output_classes: (Optional.) A nested structure of Python `type` objects
        corresponding to each component of an element of this iterator. If
        omitted, each component is assumed to be of type `tf.Tensor`.

    Returns:
      An `Iterator`.
    """
    output_types = nest.map_structure(dtypes.as_dtype, output_types)
    if output_shapes is not None:
        output_shapes = nest.map_structure_up_to(
            output_types, tensor_shape.as_shape, output_shapes)
    else:
        # No shapes given: every component gets a fully unknown shape.
        output_shapes = nest.map_structure(
            lambda _: tensor_shape.TensorShape(None), output_types)
    if output_classes is None:
        output_classes = nest.map_structure(lambda _: ops.Tensor, output_types)
    nest.assert_same_structure(output_types, output_shapes)
    output_structure = structure_lib.convert_legacy_structure(
        output_types, output_shapes, output_classes)
    string_handle = ops.convert_to_tensor(string_handle, dtype=dtypes.string)

    def _create_resource(handle_op):
        # All op variants take the same arguments.
        # pylint: disable=protected-access
        return handle_op(
            string_handle,
            output_types=output_structure._flat_types,
            output_shapes=output_structure._flat_shapes)
        # pylint: enable=protected-access

    if not compat.forward_compatible(2018, 8, 3):
        iterator_resource = _create_resource(
            gen_dataset_ops.iterator_from_string_handle)
    elif _device_stack_is_empty():
        # With no device scope active, create the resource on the CPU.
        with ops.device("/cpu:0"):
            iterator_resource = _create_resource(
                gen_dataset_ops.iterator_from_string_handle_v2)
    else:
        iterator_resource = _create_resource(
            gen_dataset_ops.iterator_from_string_handle_v2)

    return Iterator(iterator_resource, None, output_types, output_shapes,
                    output_classes)
def testConvertLegacyStructure(self, output_types, output_shapes,
                               output_classes, expected_structure):
    """Checks the converted legacy triple equals the expected structure.

    Args:
      output_types: Legacy types passed to `convert_legacy_structure`.
      output_shapes: Legacy shapes passed to `convert_legacy_structure`.
      output_classes: Legacy classes passed to `convert_legacy_structure`.
      expected_structure: The structure the conversion result must equal.
    """
    legacy_triple = (output_types, output_shapes, output_classes)
    converted = structure.convert_legacy_structure(*legacy_triple)
    self.assertEqual(converted, expected_structure)
def __init__(self,
             dataset,
             output_types,
             output_shapes=None,
             output_classes=None,
             allow_unsafe_cast=False):
    """Wraps `dataset` so that it reports the given types, shapes and classes.

    Unless `allow_unsafe_cast` is set, `dataset` must be convertible to the
    requested structure:

    * the flattened legacy output types of `dataset` must equal the flattened
      `output_types` (i.e. they match modulo nesting), and
    * each legacy output shape of `dataset` must be compatible with the
      corresponding shape in `output_shapes` (if given).

    Note: This helper permits "unsafe casts" for shapes, equivalent to using
    `tf.Tensor.set_shape()` where domain-specific knowledge is available.

    Args:
      dataset: A `Dataset` object.
      output_types: A nested structure of `tf.DType` objects.
      output_shapes: (Optional.) A nested structure of `tf.TensorShape`
        objects. If omitted, the shapes will be inherited from `dataset`.
      output_classes: (Optional.) A nested structure of class types. If
        omitted, the class types will be inherited from `dataset`.
      allow_unsafe_cast: (Optional.) If `True`, the caller may switch the
        reported output types and shapes of the restructured dataset, e.g. to
        switch a sparse tensor represented as `tf.variant` to its user-visible
        type and shape.

    Raises:
      ValueError: If either `output_types` or `output_shapes` is not
        compatible with the structure of `dataset`.
    """
    self._input_dataset = dataset
    must_validate = not allow_unsafe_cast

    input_types = dataset_ops.get_legacy_output_types(dataset)
    if must_validate:
        # Types have to agree once both nestings are flattened.
        output_types = nest.map_structure(dtypes.as_dtype, output_types)
        if nest.flatten(input_types) != nest.flatten(output_types):
            raise ValueError(
                "Dataset with output types %r cannot be restructured to have "
                "output types %r" % (input_types, output_types))

    input_shapes = dataset_ops.get_legacy_output_shapes(dataset)
    if output_shapes is not None:
        if must_validate:
            # Each requested shape has to be compatible with the shape the
            # input dataset reports for the same component.
            nest.assert_same_structure(output_types, output_shapes)
            shape_pairs = zip(
                nest.flatten(input_shapes),
                nest.flatten_up_to(output_types, output_shapes))
            if not all(
                    current.is_compatible_with(requested)
                    for current, requested in shape_pairs):
                raise ValueError(
                    "Dataset with output shapes %r cannot be restructured to have "
                    "incompatible output shapes %r" % (input_shapes,
                                                       output_shapes))
        output_shapes = nest.map_structure_up_to(
            output_types, tensor_shape.as_shape, output_shapes)
    else:
        # Re-nest the input dataset's shapes according to `output_types`.
        output_shapes = nest.pack_sequence_as(
            output_types, nest.flatten(input_shapes))

    input_classes = dataset_ops.get_legacy_output_classes(dataset)
    if output_classes is None:
        # Re-nest the input dataset's classes according to `output_types`.
        output_classes = nest.pack_sequence_as(
            output_types, nest.flatten(input_classes))

    self._structure = structure.convert_legacy_structure(
        output_types, output_shapes, output_classes)
    # Only the reported structure changes; the underlying variant is reused.
    variant_tensor = self._input_dataset._variant_tensor  # pylint: disable=protected-access
    super(_RestructuredDataset, self).__init__(dataset, variant_tensor)
def _make_reduce_func(self, reduce_func, input_dataset):
    """Wraps `reduce_func`, re-tracing it until the state shapes stabilise.

    `reduce_func` is traced with a `(state, input element)` input structure.
    After each trace the shapes it returned are merged with the current state
    shapes via `most_specific_compatible_shape`. If that merge loosens any
    state shape, `self._state_structure` is rebuilt from the loosened shapes
    and the function is traced again, until a fixed point is reached. The
    final wrapped function is stored in `self._reduce_func` and only then
    added to the default graph.

    Args:
      reduce_func: The callable handed to
        `dataset_ops.StructuredFunctionWrapper`.
      input_dataset: The dataset whose `_element_structure` describes the
        second component of the input to `reduce_func`.

    Raises:
      TypeError: If the classes or the types returned by `reduce_func` do not
        match the ones produced by `self._init_func`.
    """
    self._state_structure = self._init_func.output_structure
    state_types = self._init_func.output_types
    state_shapes = self._init_func.output_shapes
    state_classes = self._init_func.output_classes

    need_to_rerun = True
    while need_to_rerun:
        # Trace without touching the graph; only the final trace is added.
        wrapped_func = dataset_ops.StructuredFunctionWrapper(
            reduce_func,
            self._transformation_name(),
            input_structure=structure.NestedStructure(
                (self._state_structure, input_dataset._element_structure)),  # pylint: disable=protected-access
            add_to_graph=False)

        # Extract and validate class information from the returned values.
        for new_state_class, state_class in zip(
                nest.flatten(wrapped_func.output_classes),
                nest.flatten(state_classes)):
            if not issubclass(new_state_class, state_class):
                # Report the classes that were actually compared. The
                # previous message read `self._state_classes`, which this
                # method never assigns.
                raise TypeError(
                    "The element classes for the new state must match the initial "
                    "state. Expected %s; got %s." %
                    (state_classes, wrapped_func.output_classes))

        # Extract and validate type information from the returned values.
        for new_state_type, state_type in zip(
                nest.flatten(wrapped_func.output_types),
                nest.flatten(state_types)):
            if new_state_type != state_type:
                raise TypeError(
                    "The element types for the new state must match the initial "
                    "state. Expected %s; got %s." %
                    (self._init_func.output_types, wrapped_func.output_types))

        # Extract shape information from the returned values.
        flat_state_shapes = nest.flatten(state_shapes)
        flat_new_state_shapes = nest.flatten(wrapped_func.output_shapes)
        weakened_state_shapes = [
            original.most_specific_compatible_shape(new)
            for original, new in zip(flat_state_shapes, flat_new_state_shapes)
        ]

        # A rerun is needed as soon as one known-rank state shape had to be
        # loosened to stay compatible with what the function returned.
        need_to_rerun = any(
            original_shape.ndims is not None and (
                weakened_shape.ndims is None or
                original_shape.as_list() != weakened_shape.as_list())
            for original_shape, weakened_shape in zip(
                flat_state_shapes, weakened_state_shapes))

        if need_to_rerun:
            state_shapes = nest.pack_sequence_as(
                self._init_func.output_shapes, weakened_state_shapes)
            self._state_structure = structure.convert_legacy_structure(
                state_types, state_shapes, state_classes)

    self._reduce_func = wrapped_func
    self._reduce_func.function.add_to_graph(ops.get_default_graph())