def _set_save_dataset_attributes(dataset, shard_func, path):
  """Sets parameters for SaveDatasetOp and SaveDatasetV2Op.

  Args:
    dataset: The `tf.data.Dataset` being saved.
    shard_func: Optional callable mapping an element to an int64 shard index,
      or `None` to let the runtime choose sharding.
    path: Directory in which to write the dataset and its element spec.

  Returns:
    A tuple `(dataset, shard_func, use_shard_func, path)` where `dataset` has
    debug options applied, `shard_func` is the wrapped ConcreteFunction,
    `use_shard_func` indicates whether the caller supplied a shard function,
    and `path` is a string tensor.
  """
  if shard_func is None:
    use_shard_func = False
    shard_func = lambda *x: None  # a dummy function that will not be used
  else:
    use_shard_func = True

  wrapped_func = structured_function.StructuredFunctionWrapper(
      shard_func,
      "save()",
      input_structure=dataset.element_spec,
      add_to_graph=False)

  # Persist the element spec next to the data so `load()` can reconstruct the
  # dataset's structure without the caller supplying `element_spec`.
  encoded = nested_structure_coder.encode_structure(dataset.element_spec)
  gfile.MakeDirs(path)
  with gfile.GFile(os.path.join(path, DATASET_SPEC_FILENAME), "wb") as f:
    f.write(encoded.SerializeToString())

  path = ops.convert_to_tensor(path, dtype=dtypes.string, name="path")
  shard_func = wrapped_func.function
  shard_func.add_to_graph(ops.get_default_graph())

  # BUG FIX: `_apply_debug_options()` returns a (possibly wrapped) dataset
  # rather than mutating in place; the original discarded the result, so the
  # debug options were never applied. Rebind `dataset` to the returned value.
  # pylint: disable=protected-access
  dataset = dataset._apply_debug_options()

  return dataset, shard_func, use_shard_func, path
def __init__(self, input_dataset, map_func, cycle_length, block_length,
             sloppy, buffer_output_elements, prefetch_input_elements,
             name=None):
  """See `tf.data.experimental.parallel_interleave()` for details.

  Args:
    input_dataset: The dataset whose elements are mapped to nested datasets.
    map_func: Function mapping an element to a `tf.data.Dataset`; validated
      below to actually return a dataset.
    cycle_length: Number of nested datasets interleaved concurrently.
    block_length: Number of consecutive elements pulled from each nested
      dataset before cycling to the next.
    sloppy: Tri-state ordering control; `None` keeps the default, truthy
      allows non-deterministic order, falsy forces determinism.
    buffer_output_elements: Per-iterator output buffer size; defaults to
      `2 * block_length` when not given.
    prefetch_input_elements: Number of input iterators to prefetch; defaults
      to `2 * cycle_length` when not given.
    name: Optional name for the dataset, recorded in its metadata.

  Raises:
    TypeError: If `map_func` does not return a `Dataset`.
  """
  self._input_dataset = input_dataset
  self._map_func = structured_function.StructuredFunctionWrapper(
      map_func, self._transformation_name(), dataset=input_dataset)
  # `map_func` must produce a dataset; anything else cannot be interleaved.
  if not isinstance(self._map_func.output_structure,
                    dataset_ops.DatasetSpec):
    raise TypeError(
        "The `map_func` argument must return a `Dataset` object. Got "
        f"{_get_type(self._map_func.output_structure)!r}.")
  # The output elements are the elements of the *nested* datasets.
  self._element_spec = self._map_func.output_structure._element_spec  # pylint: disable=protected-access
  self._cycle_length = ops.convert_to_tensor(cycle_length,
                                             dtype=dtypes.int64,
                                             name="cycle_length")
  self._block_length = ops.convert_to_tensor(block_length,
                                             dtype=dtypes.int64,
                                             name="block_length")
  self._buffer_output_elements = convert.optional_param_to_tensor(
      "buffer_output_elements",
      buffer_output_elements,
      argument_default=2 * block_length)
  self._prefetch_input_elements = convert.optional_param_to_tensor(
      "prefetch_input_elements",
      prefetch_input_elements,
      argument_default=2 * cycle_length)
  # `sloppy` is the inverse of `deterministic`: sloppy=True -> "false",
  # sloppy=False -> "true", sloppy=None -> "default".
  if sloppy is None:
    self._deterministic = "default"
  elif sloppy:
    self._deterministic = "false"
  else:
    self._deterministic = "true"
  self._metadata = dataset_metadata_pb2.Metadata()
  if name:
    self._metadata.name = dataset_ops._validate_and_encode(name)
  kwargs = self._flat_structure
  # Only pass serialized metadata to the op when a name was given or the
  # runtime is forward-compatible with the metadata attr (added 2021-09-30).
  if name or compat.forward_compatible(2021, 9, 30):
    kwargs["metadata"] = self._metadata.SerializeToString()

  variant_tensor = ged_ops.legacy_parallel_interleave_dataset_v2(
      self._input_dataset._variant_tensor,  # pylint: disable=protected-access
      self._map_func.function.captured_inputs,
      self._cycle_length,
      self._block_length,
      self._buffer_output_elements,
      self._prefetch_input_elements,
      f=self._map_func.function,
      deterministic=self._deterministic,
      **kwargs)
  super(ParallelInterleaveDataset, self).__init__(input_dataset,
                                                  variant_tensor)
def _make_key_func(self, key_func, input_dataset):
  """Wraps `key_func` and checks that it yields a scalar int64 key."""
  self._key_func = structured_function.StructuredFunctionWrapper(
      key_func, self._transformation_name(), dataset=input_dataset)
  scalar_int64 = tensor_spec.TensorSpec([], dtypes.int64)
  if self._key_func.output_structure.is_compatible_with(scalar_int64):
    return
  raise ValueError(
      f"Invalid `key_func`. Expected `key_func` to return a scalar "
      f"tf.int64 tensor, but instead `key_func` has output "
      f"types={self._key_func.output_types} "
      f"and shapes={self._key_func.output_shapes}.")
def __init__(self, input_dataset, map_func, batch_size, num_parallel_calls,
             drop_remainder, use_legacy_function=False):
  """Fuses `map_func` and batching into a single MapAndBatchDataset op.

  Args:
    input_dataset: The dataset whose elements are mapped then batched.
    map_func: Function applied to each input element before batching.
    batch_size: Number of mapped elements per output batch.
    num_parallel_calls: Number of map invocations to run in parallel.
    drop_remainder: Whether to drop a final partial batch.
    use_legacy_function: Whether to trace `map_func` with the legacy
      function-capture machinery.
  """
  self._input_dataset = input_dataset

  self._map_func = structured_function.StructuredFunctionWrapper(
      map_func,
      "tf.data.experimental.map_and_batch()",
      dataset=input_dataset,
      use_legacy_function=use_legacy_function)
  self._batch_size_t = ops.convert_to_tensor(batch_size,
                                             dtype=dtypes.int64,
                                             name="batch_size")
  self._num_parallel_calls_t = ops.convert_to_tensor(
      num_parallel_calls, dtype=dtypes.int64, name="num_parallel_calls")
  self._drop_remainder_t = ops.convert_to_tensor(drop_remainder,
                                                 dtype=dtypes.bool,
                                                 name="drop_remainder")

  # Try to statically resolve `drop_remainder` so the batch dimension can be
  # a known constant in the element spec when all batches are full.
  constant_drop_remainder = tensor_util.constant_value(
      self._drop_remainder_t)
  # pylint: disable=protected-access
  if constant_drop_remainder:
    # NOTE(mrry): `constant_drop_remainder` may be `None` (unknown statically)
    # or `False` (explicitly retaining the remainder).
    # pylint: disable=g-long-lambda
    self._element_spec = nest.map_structure(
        lambda component_spec: component_spec._batch(
            tensor_util.constant_value(self._batch_size_t)),
        self._map_func.output_structure)
  else:
    # Batch dimension is unknown: the last batch may be partial.
    self._element_spec = nest.map_structure(
        lambda component_spec: component_spec._batch(None),
        self._map_func.output_structure)
  # pylint: enable=protected-access
  variant_tensor = ged_ops.map_and_batch_dataset(
      self._input_dataset._variant_tensor,  # pylint: disable=protected-access
      self._map_func.function.captured_inputs,
      f=self._map_func.function,
      batch_size=self._batch_size_t,
      num_parallel_calls=self._num_parallel_calls_t,
      drop_remainder=self._drop_remainder_t,
      preserve_cardinality=True,
      **self._flat_structure)
  super(_MapAndBatchDataset, self).__init__(input_dataset, variant_tensor)
def __init__(self, input_dataset, map_func):
  """See `Dataset.flat_map()` for details.

  Args:
    input_dataset: The dataset whose elements are mapped to nested datasets
      and flattened.
    map_func: Function mapping an element to a `tf.data.Dataset`.
  """
  self._input_dataset = input_dataset
  # Force the mapped function to execute on a single-threaded executor,
  # which is what distinguishes this variant from the regular flat_map.
  self._map_func = structured_function.StructuredFunctionWrapper(
      map_func,
      self._transformation_name(),
      dataset=input_dataset,
      defun_kwargs={"_executor": "SINGLE_THREADED_EXECUTOR"})
  # Output elements are the elements of the nested datasets.
  self._structure = self._map_func.output_structure._element_spec  # pylint: disable=protected-access
  variant_tensor = gen_dataset_ops.flat_map_dataset(
      input_dataset._variant_tensor,  # pylint: disable=protected-access
      self._map_func.function.captured_inputs,
      f=self._map_func.function,
      **self._flat_structure)
  super(SingleThreadedFlatMapDataset, self).__init__(input_dataset,
                                                     variant_tensor)
def __init__(self, input_dataset, map_func, use_inter_op_parallelism=True):
  """See `Dataset.map()` for details.

  Args:
    input_dataset: The dataset whose elements are mapped.
    map_func: Function applied to each element of `input_dataset`.
    use_inter_op_parallelism: Whether the map op may use inter-op
      parallelism; forwarded to the ExperimentalMapDataset op.
  """
  self._input_dataset = input_dataset
  self._use_inter_op_parallelism = use_inter_op_parallelism
  # `experimental_ints_on_device` keeps int tensors on the (GPU) device
  # rather than pinning them to host memory when tracing `map_func`.
  self._map_func = structured_function.StructuredFunctionWrapper(
      map_func,
      self._transformation_name(),
      dataset=input_dataset,
      defun_kwargs={"experimental_ints_on_device": True})
  variant_tensor = ged_ops.experimental_map_dataset(
      self._input_dataset._variant_tensor,  # pylint: disable=protected-access
      self._map_func.function.captured_inputs,
      f=self._map_func.function,
      use_inter_op_parallelism=self._use_inter_op_parallelism,
      **self._flat_structure)
  super(_MapOnGpuDataset, self).__init__(input_dataset, variant_tensor)
def __init__(self, path, element_spec=None, compression=None,
             reader_func=None):
  """Loads a previously saved dataset from `path`.

  Args:
    path: Directory containing the saved dataset and its serialized
      element spec (`DATASET_SPEC_FILENAME`).
    element_spec: Structure of the dataset's elements. Required in graph
      mode; in eager mode it is read back from disk when omitted.
    compression: Compression used when the dataset was saved; forwarded to
      the LoadDataset op.
    reader_func: Function mapping a dataset of shard datasets to a dataset
      of elements. Defaults to interleaving all shards with one cycle slot
      per CPU and AUTOTUNE parallelism.

  Raises:
    ValueError: If `element_spec` is omitted in graph mode.
  """
  if reader_func is None:
    # Default reader: interleave every shard dataset, one cycle slot per
    # CPU, letting tf.data tune the parallelism.
    reader_func = lambda datasets: datasets.interleave(  # pylint:disable=g-long-lambda
        lambda x: x,
        cycle_length=multiprocessing.cpu_count(),
        num_parallel_calls=dataset_ops.AUTOTUNE)

  self._path = path
  if element_spec is None:
    # Reading the spec file requires eager execution; in graph mode the
    # caller must supply it explicitly.
    if not context.executing_eagerly():
      raise ValueError(
          "In graph mode the `element_spec` argument must be provided.")
    with gfile.GFile(os.path.join(path, DATASET_SPEC_FILENAME), "rb") as f:
      encoded_spec = f.read()
    struct_pb = nested_structure_coder.struct_pb2.StructuredValue()
    struct_pb.ParseFromString(encoded_spec)
    # NOTE(review): relies on the `StructureCoder` class API; newer
    # `nested_structure_coder` versions expose module-level `decode_proto`
    # instead — confirm against the pinned TF version.
    coder = nested_structure_coder.StructureCoder()
    spec = coder.decode_proto(struct_pb)
    self._element_spec = spec
  else:
    self._element_spec = element_spec
  self._compression = compression
  self._reader_func = structured_function.StructuredFunctionWrapper(
      reader_func,
      "load()",
      # Dataset of datasets of input elements
      input_structure=dataset_ops.DatasetSpec(
          dataset_ops.DatasetSpec(self._element_spec)))

  variant_tensor = gen_experimental_dataset_ops.load_dataset(
      path,
      reader_func_other_args=self._reader_func.function.captured_inputs,
      compression=compression,
      reader_func=self._reader_func.function,
      **self._flat_structure)
  super(_LoadDataset, self).__init__(variant_tensor)
def _make_finalize_func(self, finalize_func):
  """Wraps `finalize_func` as a defun taking the reducer's state."""
  wrapper = structured_function.StructuredFunctionWrapper(
      finalize_func,
      self._transformation_name(),
      input_structure=self._state_structure)
  self._finalize_func = wrapper
def _make_reduce_func(self, reduce_func, input_dataset):
  """Make wrapping defun for reduce_func.

  Retraces `reduce_func` until the inferred state structure reaches a fixed
  point: each pass may weaken the state shapes (e.g. a concat that grows a
  dimension), which requires retracing with the weaker shapes.

  Args:
    reduce_func: Function of `(state, input_element) -> new_state`.
    input_dataset: Dataset supplying the input-element structure.

  Raises:
    TypeError: If the new state's classes or dtypes are incompatible with
      the old state's.
  """
  # Iteratively rerun the reduce function until reaching a fixed point on
  # `self._state_structure`.
  self._state_structure = self._init_func.output_structure
  state_types = self._init_func.output_types
  state_shapes = self._init_func.output_shapes
  state_classes = self._init_func.output_classes
  need_to_rerun = True
  while need_to_rerun:
    wrapped_func = structured_function.StructuredFunctionWrapper(
        reduce_func,
        self._transformation_name(),
        input_structure=(self._state_structure, input_dataset.element_spec),
        add_to_graph=False)

    # Extract and validate class information from the returned values.
    for new_state_class, state_class in zip(
        nest.flatten(wrapped_func.output_classes),
        nest.flatten(state_classes)):
      if not issubclass(new_state_class, state_class):
        # NOTE(review): this message interpolates `self._state_classes`, but
        # the loop tracks the local `state_classes`; confirm the attribute
        # exists on this class, otherwise formatting the message itself would
        # raise AttributeError instead of this TypeError.
        raise TypeError(
            f"Invalid `reducer`. The output class of the "
            f"`reducer.reduce_func` {wrapped_func.output_classes}, "
            f"does not match the class of the reduce state "
            f"{self._state_classes}.")

    # Extract and validate type information from the returned values.
    for new_state_type, state_type in zip(
        nest.flatten(wrapped_func.output_types),
        nest.flatten(state_types)):
      if new_state_type != state_type:
        raise TypeError(
            f"Invalid `reducer`. The element types for the new state "
            f"{wrapped_func.output_types} do not match the element types "
            f"of the old state {self._init_func.output_types}.")

    # Extract shape information from the returned values.
    flat_state_shapes = nest.flatten(state_shapes)
    flat_new_state_shapes = nest.flatten(wrapped_func.output_shapes)
    # Weaken each state shape to the most specific shape compatible with
    # both the old and the new state.
    weakened_state_shapes = [
        original.most_specific_compatible_shape(new)
        for original, new in zip(flat_state_shapes, flat_new_state_shapes)
    ]

    # Rerun only if weakening actually changed some shape.
    need_to_rerun = False
    for original_shape, weakened_shape in zip(flat_state_shapes,
                                              weakened_state_shapes):
      if original_shape.ndims is not None and (
          weakened_shape.ndims is None or
          original_shape.as_list() != weakened_shape.as_list()):
        need_to_rerun = True
        break

    if need_to_rerun:
      state_shapes = nest.pack_sequence_as(
          self._init_func.output_shapes, weakened_state_shapes)
      self._state_structure = structure.convert_legacy_structure(
          state_types, state_shapes, state_classes)

  self._reduce_func = wrapped_func
  self._reduce_func.function.add_to_graph(ops.get_default_graph())
def _make_init_func(self, init_func):
  """Wraps `init_func` as a defun taking a scalar int64 key."""
  key_structure = tensor_spec.TensorSpec([], dtypes.int64)
  self._init_func = structured_function.StructuredFunctionWrapper(
      init_func,
      self._transformation_name(),
      input_structure=key_structure)