def __init__(self, input_dataset, path, shard_func, compression=None,
             reader_func=None, pending_snapshot_expiry_seconds=None,
             use_legacy_function=False):
  """Initializes the snapshot transformation over `input_dataset`.

  Traces `reader_func` (how saved shards are read back) and `shard_func`
  (how elements are assigned to shards), validates that `shard_func`
  yields a scalar integer, and builds the `SnapshotDatasetV2` op.

  NOTE(review): `pending_snapshot_expiry_seconds` is accepted but never
  used in this body — presumably kept for signature compatibility;
  confirm against callers.
  """
  if reader_func is None:
    # Default: read the shard datasets in parallel, one reader per core.
    reader_func = lambda datasets: datasets.interleave(  # pylint:disable=g-long-lambda
        lambda x: x,
        cycle_length=multiprocessing.cpu_count(),
        num_parallel_calls=dataset_ops.AUTOTUNE)

  self._input_dataset = input_dataset
  self._path = path
  self._compression = compression

  self._reader_func = dataset_ops.StructuredFunctionWrapper(
      reader_func,
      self._transformation_name() + ".reader_func",
      # Dataset of datasets of input elements
      input_structure=dataset_ops.DatasetSpec(
          dataset_ops.DatasetSpec(input_dataset.element_spec)),
      use_legacy_function=use_legacy_function)
  self._shard_func = dataset_ops.StructuredFunctionWrapper(
      shard_func,
      self._transformation_name() + ".shard_func",
      dataset=input_dataset,
      use_legacy_function=use_legacy_function)

  # The shard function must map each element to a scalar int32 or int64
  # shard index.
  if ((not self._shard_func.output_structure.is_compatible_with(
      tensor_spec.TensorSpec([], dtypes.int32))) and
      (not self._shard_func.output_structure.is_compatible_with(
          tensor_spec.TensorSpec([], dtypes.int64)))):
    raise TypeError(
        "shard_func must return a 0-dimension tensor containing an int.")

  variant_tensor = ged_ops.snapshot_dataset_v2(
      input_dataset._variant_tensor,  # pylint: disable=protected-access
      path,
      self._reader_func.function.captured_inputs,
      self._shard_func.function.captured_inputs,
      compression=compression,
      reader_func=self._reader_func.function,
      shard_func=self._shard_func.function,
      **self._flat_structure)
  super(_SnapshotDataset, self).__init__(input_dataset, variant_tensor)
def _make_reduce_func(self, reduce_func, input_dataset):
  """Make wrapping defun for reduce_func.

  `reduce_func` takes `(state, input_element)` and returns a new state.
  Because the returned state's shapes may be less specific than the
  initial state's, the function is retraced with progressively weakened
  state shapes until the shapes reach a fixed point.
  """
  # Iteratively rerun the reduce function until reaching a fixed point on
  # `self._state_shapes`.
  need_to_rerun = True
  while need_to_rerun:
    wrapped_func = dataset_ops.StructuredFunctionWrapper(
        reduce_func,
        self._transformation_name(),
        input_classes=(self._state_classes, input_dataset.output_classes),
        input_shapes=(self._state_shapes, input_dataset.output_shapes),
        input_types=(self._state_types, input_dataset.output_types),
        add_to_graph=False)

    # Extract and validate class information from the returned values.
    for new_state_class, state_class in zip(
        nest.flatten(wrapped_func.output_classes),
        nest.flatten(self._state_classes)):
      if not issubclass(new_state_class, state_class):
        raise TypeError(
            "The element classes for the new state must match the initial "
            "state. Expected %s; got %s." %
            (self._state_classes, wrapped_func.output_classes))

    # Extract and validate type information from the returned values.
    for new_state_type, state_type in zip(
        nest.flatten(wrapped_func.output_types),
        nest.flatten(self._state_types)):
      if new_state_type != state_type:
        raise TypeError(
            "The element types for the new state must match the initial "
            "state. Expected %s; got %s." %
            (self._state_types, wrapped_func.output_types))

    # Extract shape information from the returned values.
    flat_state_shapes = nest.flatten(self._state_shapes)
    flat_new_state_shapes = nest.flatten(wrapped_func.output_shapes)
    # For each component, take the most specific shape compatible with both
    # the current state shape and the newly returned one.
    weakened_state_shapes = [
        original.most_specific_compatible_shape(new)
        for original, new in zip(flat_state_shapes, flat_new_state_shapes)
    ]

    need_to_rerun = False
    for original_shape, weakened_shape in zip(flat_state_shapes,
                                              weakened_state_shapes):
      # Rerun only if weakening actually changed a known shape.
      if original_shape.ndims is not None and (
          weakened_shape.ndims is None or
          original_shape.as_list() != weakened_shape.as_list()):
        need_to_rerun = True
        break

    if need_to_rerun:
      # Adopt the weakened shapes and retrace so the function accepts the
      # less specific state on the next iteration.
      self._state_shapes = nest.pack_sequence_as(self._state_shapes,
                                                 weakened_state_shapes)

  self._reduce_func = wrapped_func.function
  self._reduce_func.add_to_graph(ops.get_default_graph())
def __init__(self, path, element_spec=None, compression=None,
             reader_func=None):
  """Initializes a dataset that loads previously saved elements from `path`.

  If `element_spec` is not supplied, it is deserialized from the spec
  proto file stored alongside the saved data.
  """
  if reader_func is None:
    # Default: interleave the shard datasets, one reader per CPU core.
    reader_func = lambda datasets: datasets.interleave(  # pylint:disable=g-long-lambda
        lambda x: x,
        cycle_length=multiprocessing.cpu_count(),
        num_parallel_calls=dataset_ops.AUTOTUNE)

  self._path = path
  if element_spec is None:
    # Recover the element spec from the serialized StructuredValue proto
    # written next to the data (see the save path, which writes
    # DATASET_SPEC_FILENAME).
    with gfile.GFile(os.path.join(path, DATASET_SPEC_FILENAME), "rb") as f:
      encoded_spec = f.read()
    struct_pb = nested_structure_coder.struct_pb2.StructuredValue()
    struct_pb.ParseFromString(encoded_spec)
    coder = nested_structure_coder.StructureCoder()
    spec = coder.decode_proto(struct_pb)
    self._element_spec = spec
  else:
    self._element_spec = element_spec
  self._compression = compression
  self._reader_func = dataset_ops.StructuredFunctionWrapper(
      reader_func,
      "load()",
      # Dataset of datasets of input elements
      input_structure=dataset_ops.DatasetSpec(
          dataset_ops.DatasetSpec(self._element_spec)))

  variant_tensor = gen_experimental_dataset_ops.load_dataset(
      path,
      reader_func_other_args=self._reader_func.function.captured_inputs,
      compression=compression,
      reader_func=self._reader_func.function,
      **self._flat_structure)
  super(_LoadDataset, self).__init__(variant_tensor)
def __init__(self, input_dataset, map_func, batch_size, num_parallel_calls,
             drop_remainder, use_legacy_function=False):
  """See `Dataset.map()` for details."""
  self._input_dataset = input_dataset
  self._map_func = dataset_ops.StructuredFunctionWrapper(
      map_func,
      "tf.data.experimental.map_and_batch()",
      dataset=input_dataset,
      use_legacy_function=use_legacy_function)
  self._batch_size_t = ops.convert_to_tensor(
      batch_size, dtype=dtypes.int64, name="batch_size")
  self._num_parallel_calls_t = ops.convert_to_tensor(
      num_parallel_calls, dtype=dtypes.int64, name="num_parallel_calls")
  self._drop_remainder_t = ops.convert_to_tensor(
      drop_remainder, dtype=dtypes.bool, name="drop_remainder")

  # If `drop_remainder` is statically True, the batch dimension is a known
  # constant; otherwise it must remain unknown (None).
  constant_drop_remainder = tensor_util.constant_value(self._drop_remainder_t)
  if constant_drop_remainder:
    # NOTE(mrry): `constant_drop_remainder` may be `None` (unknown statically)
    # or `False` (explicitly retaining the remainder).
    self._structure = self._map_func.output_structure._batch(  # pylint: disable=protected-access
        tensor_util.constant_value(self._batch_size_t))
  else:
    self._structure = self._map_func.output_structure._batch(None)  # pylint: disable=protected-access
  variant_tensor = ged_ops.experimental_map_and_batch_dataset(
      self._input_dataset._variant_tensor,  # pylint: disable=protected-access
      self._map_func.function.captured_inputs,
      f=self._map_func.function,
      batch_size=self._batch_size_t,
      num_parallel_calls=self._num_parallel_calls_t,
      drop_remainder=self._drop_remainder_t,
      **dataset_ops.flat_structure(self))
  super(_MapAndBatchDataset, self).__init__(input_dataset, variant_tensor)
def __init__(self, input_dataset, map_func, cycle_length, block_length,
             sloppy, buffer_output_elements, prefetch_input_elements):
  """See `tf.data.experimental.parallel_interleave()` for details."""
  self._input_dataset = input_dataset
  self._map_func = dataset_ops.StructuredFunctionWrapper(
      map_func, self._transformation_name(), dataset=input_dataset)
  # `map_func` must produce a dataset per input element so its results can
  # be interleaved.
  if not isinstance(self._map_func.output_structure,
                    dataset_ops.DatasetSpec):
    raise TypeError("`map_func` must return a `Dataset` object.")
  self._element_spec = self._map_func.output_structure._element_spec  # pylint: disable=protected-access
  self._cycle_length = ops.convert_to_tensor(
      cycle_length, dtype=dtypes.int64, name="cycle_length")
  self._block_length = ops.convert_to_tensor(
      block_length, dtype=dtypes.int64, name="block_length")
  self._sloppy = ops.convert_to_tensor(
      sloppy, dtype=dtypes.bool, name="sloppy")
  # When unspecified, the buffering defaults scale with the static
  # block/cycle lengths.
  self._buffer_output_elements = convert.optional_param_to_tensor(
      "buffer_output_elements",
      buffer_output_elements,
      argument_default=2 * block_length)
  self._prefetch_input_elements = convert.optional_param_to_tensor(
      "prefetch_input_elements",
      prefetch_input_elements,
      argument_default=2 * cycle_length)
  variant_tensor = ged_ops.parallel_interleave_dataset(
      self._input_dataset._variant_tensor,  # pylint: disable=protected-access
      self._map_func.function.captured_inputs,
      self._cycle_length,
      self._block_length,
      self._sloppy,
      self._buffer_output_elements,
      self._prefetch_input_elements,
      f=self._map_func.function,
      **self._flat_structure)
  super(ParallelInterleaveDataset, self).__init__(input_dataset,
                                                  variant_tensor)
def __init__(self, path, element_spec, compression=None, reader_func=None):
  """Initializes a dataset that loads saved elements of `element_spec` from `path`."""
  if reader_func is None:
    # Default reader: interleave the shard datasets, one per CPU core.
    def _default_reader(datasets):
      return datasets.interleave(
          lambda x: x,
          cycle_length=multiprocessing.cpu_count(),
          num_parallel_calls=dataset_ops.AUTOTUNE)

    reader_func = _default_reader

  self._path = path
  self._element_spec = element_spec
  self._compression = compression

  # The reader function consumes a dataset of per-shard datasets.
  nested_spec = dataset_ops.DatasetSpec(dataset_ops.DatasetSpec(element_spec))
  self._reader_func = dataset_ops.StructuredFunctionWrapper(
      reader_func, "load()", input_structure=nested_spec)

  reader_fn = self._reader_func.function
  variant_tensor = gen_experimental_dataset_ops.load_dataset(
      path,
      reader_func_other_args=reader_fn.captured_inputs,
      compression=compression,
      reader_func=reader_fn,
      **self._flat_structure)
  super(_LoadDataset, self).__init__(variant_tensor)
def __init__(self, input_dataset, predicate):
  """See `take_while()` for details."""
  self._input_dataset = input_dataset
  wrapped_func = dataset_ops.StructuredFunctionWrapper(
      predicate, "tf.data.experimental.take_while()",
      dataset=self._input_dataset)

  # The predicate decides when iteration stops, so it must yield a scalar
  # boolean per element.
  if not wrapped_func.output_structure.is_compatible_with(
      structure_lib.TensorStructure(dtypes.bool, [])):
    raise ValueError(
        "`predicate` must return a scalar boolean tensor.")

  self._predicate = wrapped_func
  # Forward-compatibility gate: the op was exposed under two names; use the
  # non-"experimental" one once past the compatibility window.
  if compat.forward_compatible(2019, 8, 3):
    var_tensor = gen_experimental_dataset_ops.take_while_dataset(
        self._input_dataset._variant_tensor,  # pylint: disable=protected-access
        other_arguments=self._predicate.function.captured_inputs,
        predicate=self._predicate.function,
        **self._flat_structure)
  else:
    var_tensor = gen_experimental_dataset_ops.experimental_take_while_dataset(
        self._input_dataset._variant_tensor,  # pylint: disable=protected-access
        other_arguments=self._predicate.function.captured_inputs,
        predicate=self._predicate.function,
        **self._flat_structure)
  super(_TakeWhileDataset, self).__init__(input_dataset, var_tensor)
def _set_save_dataset_attributes(dataset, shard_func, path):
  """Sets parameters for SaveDatasetOp and SaveDatasetV2Op.

  Returns a `(dataset, shard_func, use_shard_func, path)` tuple: the
  dataset with debug options applied, the traced shard function, whether a
  user-supplied shard function should be used, and `path` as a string
  tensor. Also writes the dataset's element spec next to the data so it
  can be recovered at load time.
  """
  if shard_func is None:
    use_shard_func = False
    shard_func = lambda *x: None  # a dummy function that will not be used
  else:
    use_shard_func = True

  wrapped_func = dataset_ops.StructuredFunctionWrapper(
      shard_func,
      "save()",
      input_structure=dataset.element_spec,
      add_to_graph=False)

  # Serialize the element spec so load() does not need an explicit spec.
  coder = nested_structure_coder.StructureCoder()
  encoded = coder.encode_structure(dataset.element_spec)
  gfile.MakeDirs(path)
  with gfile.GFile(os.path.join(path, DATASET_SPEC_FILENAME), "wb") as f:
    f.write(encoded.SerializeToString())

  path = ops.convert_to_tensor(path, dtype=dtypes.string, name="path")
  shard_func = wrapped_func.function
  shard_func.add_to_graph(ops.get_default_graph())

  # pylint: disable=protected-access
  dataset._apply_debug_options()
  return dataset, shard_func, use_shard_func, path,
def __init__(self, input_dataset, map_func, cycle_length, block_length,
             sloppy, buffer_output_elements, prefetch_input_elements,
             name=None):
  """See `tf.data.experimental.parallel_interleave()` for details."""
  self._input_dataset = input_dataset
  self._map_func = dataset_ops.StructuredFunctionWrapper(
      map_func, self._transformation_name(), dataset=input_dataset)
  if not isinstance(self._map_func.output_structure,
                    dataset_ops.DatasetSpec):
    raise TypeError(
        "The `map_func` argument must return a `Dataset` object. Got "
        f"{_get_type(self._map_func.output_structure)!r}.")
  self._element_spec = self._map_func.output_structure._element_spec  # pylint: disable=protected-access
  self._cycle_length = ops.convert_to_tensor(
      cycle_length, dtype=dtypes.int64, name="cycle_length")
  self._block_length = ops.convert_to_tensor(
      block_length, dtype=dtypes.int64, name="block_length")
  # When unspecified, the buffering defaults scale with the static
  # block/cycle lengths.
  self._buffer_output_elements = convert.optional_param_to_tensor(
      "buffer_output_elements",
      buffer_output_elements,
      argument_default=2 * block_length)
  self._prefetch_input_elements = convert.optional_param_to_tensor(
      "prefetch_input_elements",
      prefetch_input_elements,
      argument_default=2 * cycle_length)
  # Translate the tri-state `sloppy` flag into the op's `deterministic`
  # string attr: None -> "default", True -> "false", False -> "true".
  if sloppy is None:
    self._deterministic = "default"
  elif sloppy:
    self._deterministic = "false"
  else:
    self._deterministic = "true"
  self._metadata = dataset_metadata_pb2.Metadata()
  if name:
    self._metadata.name = dataset_ops._validate_and_encode(name)
  kwargs = self._flat_structure
  # Attach serialized metadata only when a name is set or the
  # forward-compatibility window has passed.
  if name or compat.forward_compatible(2021, 9, 30):
    kwargs["metadata"] = self._metadata.SerializeToString()

  variant_tensor = ged_ops.legacy_parallel_interleave_dataset_v2(
      self._input_dataset._variant_tensor,  # pylint: disable=protected-access
      self._map_func.function.captured_inputs,
      self._cycle_length,
      self._block_length,
      self._buffer_output_elements,
      self._prefetch_input_elements,
      f=self._map_func.function,
      deterministic=self._deterministic,
      **kwargs)
  super(ParallelInterleaveDataset, self).__init__(input_dataset,
                                                  variant_tensor)
def _make_key_func(self, key_func, input_dataset):
  """Wraps `key_func` and validates that it produces a scalar int64 key."""
  wrapped = dataset_ops.StructuredFunctionWrapper(
      key_func, self._transformation_name(), dataset=input_dataset)
  self._key_func = wrapped
  # The grouping key must be a scalar int64 tensor.
  expected = structure.TensorStructure(dtypes.int64, [])
  if not wrapped.output_structure.is_compatible_with(expected):
    raise ValueError(
        "`key_func` must return a single tf.int64 tensor. "
        "Got type=%s and shape=%s" %
        (wrapped.output_types, wrapped.output_shapes))
def _make_finalize_func(self, finalize_func):
  """Wraps `finalize_func`, which maps the final reducer state to outputs."""
  # The finalize function is traced against the accumulated state signature.
  wrapper = dataset_ops.StructuredFunctionWrapper(
      finalize_func,
      "tf.contrib.data.group_by_reducer()",
      input_classes=self._state_classes,
      input_shapes=self._state_shapes,
      input_types=self._state_types)
  self._finalize_func = wrapper.function
  # Whatever finalize produces becomes the transformation's output signature.
  self._output_types = wrapper.output_types
  self._output_shapes = wrapper.output_shapes
  self._output_classes = wrapper.output_classes
def _make_init_func(self, init_func):
  """Wraps `init_func`, which maps a scalar int64 key to the initial state."""
  wrapper = dataset_ops.StructuredFunctionWrapper(
      init_func,
      "tf.contrib.data.group_by_reducer()",
      input_classes=ops.Tensor,
      input_shapes=tensor_shape.scalar(),
      input_types=dtypes.int64)
  self._init_func = wrapper.function
  # The reducer's state signature is defined by what init produces.
  self._state_types = wrapper.output_types
  self._state_shapes = wrapper.output_shapes
  self._state_classes = wrapper.output_classes
def __init__(
    self,
    input_dataset,
    map_func,
    cycle_length,
    block_length,
    num_parallel_calls,
    deterministic,
    buffer_output_elements=None,  # backward compatibility with TF2.0
    prefetch_input_elements=None):  # backward compatibility with TF2.0
  """See `Dataset.interleave()` for details.

  Selects the newest interleave op the arguments allow: v4 when explicit
  (non-autotune) buffering/prefetching is requested, v3 when a
  non-default determinism setting is requested, and v2 otherwise.
  """
  self._input_dataset = input_dataset
  self._map_func = ds.StructuredFunctionWrapper(
      map_func._func, self._transformation_name(), input_dataset)
  self._cycle_length = cycle_length
  self._block_length = block_length
  self._buffer_output_elements = buffer_output_elements
  self._prefetch_input_elements = prefetch_input_elements
  self._num_parallel_calls = num_parallel_calls
  self._deterministic = deterministic

  if (buffer_output_elements and
      buffer_output_elements != ds_helpers.autotune_flag()) or \
     (prefetch_input_elements and
      prefetch_input_elements != ds_helpers.autotune_flag()):
    variant_tensor = gen_dataset_ops.parallel_interleave_dataset_v4(
        input_dataset._variant_tensor,  # pylint: disable=protected-access
        self._map_func.function.captured_inputs,  # pylint: disable=protected-access
        self._cycle_length,
        self._block_length,
        self._buffer_output_elements,
        self._prefetch_input_elements,
        self._num_parallel_calls,
        f=self._map_func.function,
        deterministic=deterministic,
        **self._flat_structure)
  elif deterministic != "default":
    variant_tensor = gen_dataset_ops.parallel_interleave_dataset_v3(
        input_dataset._variant_tensor,  # pylint: disable=protected-access
        self._map_func.function.captured_inputs,  # pylint: disable=protected-access
        self._cycle_length,
        self._block_length,
        self._num_parallel_calls,
        f=self._map_func.function,
        # BUG FIX: this previously passed `deterministic_string`, a name that
        # is not defined anywhere in this scope, so this branch raised a
        # NameError whenever a non-default determinism setting was used.
        # `deterministic` already holds the value to forward.
        deterministic=deterministic,
        **self._flat_structure)
  else:
    variant_tensor = gen_dataset_ops.parallel_interleave_dataset_v2(
        input_dataset._variant_tensor,  # pylint: disable=protected-access
        self._map_func.function.captured_inputs,  # pylint: disable=protected-access
        self._cycle_length,
        self._block_length,
        self._num_parallel_calls,
        f=self._map_func.function,
        **self._flat_structure)
  super(TntParallelInterleaveDataset, self).__init__(input_dataset,
                                                     variant_tensor)
def __init__(self, input_dataset, map_func):
  """See `map_x_dataset()` for details."""
  super(_MapXDataset, self).__init__(input_dataset)
  self._input_dataset = input_dataset
  # Trace `map_func` against the input dataset's element signature.
  wrapper = dataset_ops.StructuredFunctionWrapper(
      map_func, self._transformation_name(), dataset=input_dataset)
  self._map_func = wrapper.function
  # The mapped function determines the output signature.
  self._output_types = wrapper.output_types
  self._output_shapes = wrapper.output_shapes
  self._output_classes = wrapper.output_classes
def _make_key_func(self, key_func, input_dataset):
  """Wraps `key_func`, coercing its result to a scalar int64 tensor."""

  def key_func_wrapper(*args):
    # Allow `key_func` to return a plain Python int (or other convertible).
    return ops.convert_to_tensor(key_func(*args), dtype=dtypes.int64)

  wrapped = dataset_ops.StructuredFunctionWrapper(
      key_func_wrapper, self._transformation_name(), dataset=input_dataset)
  self._key_func = wrapped
  scalar_int64 = structure.TensorStructure(dtypes.int64, [])
  if not wrapped.output_structure.is_compatible_with(scalar_int64):
    raise ValueError(
        "`key_func` must return a single tf.int64 scalar tensor.")
def __init__(self, input_dataset, predicate, use_legacy_function=False):
  """See `Dataset.filter()` for details."""
  self._input_dataset = input_dataset
  # Trace the user predicate against the input element signature.
  wrapped_predicate = ds.StructuredFunctionWrapper(
      predicate._func, self._transformation_name(), input_dataset)
  self._predicate = wrapped_predicate
  predicate_fn = wrapped_predicate.function
  variant_tensor = gen_dataset_ops.filter_dataset(
      input_dataset._variant_tensor,  # pylint: disable=protected-access
      other_arguments=predicate_fn.captured_inputs,
      predicate=predicate_fn,
      **self._flat_structure)
  super(TntFilterDataset, self).__init__(input_dataset, variant_tensor)
def __init__(self, input_dataset, map_func, use_inter_op_parallelism=True):
  """See `Dataset.map()` for details."""
  super(_MapOnGpuDataset, self).__init__(input_dataset)
  self._input_dataset = input_dataset
  self._use_inter_op_parallelism = use_inter_op_parallelism
  self._map_func = dataset_ops.StructuredFunctionWrapper(
      map_func,
      self._transformation_name(),
      dataset=input_dataset,
      # NOTE(review): presumably keeps int32 tensors on the device so the
      # traced function can run on GPU — confirm against the defun docs.
      defun_kwargs={"experimental_ints_on_device": True})
def __init__(self, input_dataset, map_func):
  """See `Dataset.flat_map()` for details."""
  self._input_dataset = input_dataset
  # Trace the user function against the input element signature.
  wrapper = ds.StructuredFunctionWrapper(
      map_func._func, self._transformation_name(), input_dataset)
  self._map_func = wrapper
  map_fn = wrapper.function
  variant_tensor = gen_dataset_ops.flat_map_dataset(
      self._input_dataset._variant_tensor,  # pylint: disable=protected-access
      map_fn.captured_inputs,
      f=map_fn,
      **self._flat_structure)
  super(TntFlatMapDataset, self).__init__(input_dataset, variant_tensor)
def _make_init_func(self, init_func):
  """Wraps `init_func`, which maps a scalar int64 key to the initial state."""
  wrapper = dataset_ops.StructuredFunctionWrapper(
      init_func,
      self._transformation_name(),
      input_classes=ops.Tensor,
      input_shapes=tensor_shape.scalar(),
      input_types=dtypes.int64)
  self._init_func = wrapper.function
  # The state signature for the reduction is defined by init's outputs.
  self._state_types = wrapper.output_types
  self._state_shapes = wrapper.output_shapes
  self._state_classes = wrapper.output_classes
def _make_key_func(self, key_func, input_dataset):
  """Wraps `key_func` and checks that it yields a scalar int64 key."""
  wrapper = dataset_ops.StructuredFunctionWrapper(
      key_func, self._transformation_name(), dataset=input_dataset)
  # Short-circuit: only check the shape when the dtype already matches.
  valid = (wrapper.output_types == dtypes.int64 and
           wrapper.output_shapes.is_compatible_with(tensor_shape.scalar()))
  if not valid:
    raise ValueError(
        "`key_func` must return a single tf.int64 tensor. "
        "Got type=%s and shape=%s" %
        (wrapper.output_types, wrapper.output_shapes))
  self._key_func = wrapper.function
def _make_key_func(self, key_func, input_dataset):
  """Make wrapping Defun for key_func."""

  def key_func_wrapper(*args):
    # Allow `key_func` to return a plain Python int.
    return ops.convert_to_tensor(key_func(*args), dtype=dtypes.int64)

  wrapped_func = dataset_ops.StructuredFunctionWrapper(
      key_func_wrapper, "tf.contrib.data.group_by_window()", input_dataset)
  # The grouping key must be a scalar int64 tensor.
  if not (
      wrapped_func.output_types == dtypes.int64 and
      wrapped_func.output_shapes.is_compatible_with(tensor_shape.scalar())):
    raise ValueError(
        "`key_func` must return a single tf.int64 scalar tensor.")
  self._key_func = wrapped_func.function
def _make_key_func(self, key_func, input_dataset):
  """Make wrapping defun for key_func."""
  self._key_func = dataset_ops.StructuredFunctionWrapper(
      key_func, self._transformation_name(), dataset=input_dataset)
  # The grouping key must be a scalar int64 tensor.
  if not self._key_func.output_structure.is_compatible_with(
      tensor_spec.TensorSpec([], dtypes.int64)):
    raise ValueError(
        f"Invalid `key_func`. Expected `key_func` to return a scalar "
        f"tf.int64 tensor, but instead `key_func` has output "
        f"types={self._key_func.output_types} "
        f"and shapes={self._key_func.output_shapes}.")
def _make_key_func(self, key_func, input_dataset):
  """Make wrapping Defun for key_func."""
  wrapped_func = dataset_ops.StructuredFunctionWrapper(
      key_func, "tf.data.experimental.group_by_reducer()", input_dataset)
  # The grouping key must be a scalar int64 tensor.
  if not (
      wrapped_func.output_types == dtypes.int64 and
      wrapped_func.output_shapes.is_compatible_with(tensor_shape.scalar())):
    raise ValueError(
        "`key_func` must return a single tf.int64 tensor. "
        "Got type=%s and shape=%s" %
        (wrapped_func.output_types, wrapped_func.output_shapes))
  self._key_func = wrapped_func.function
def _make_finalize_func(self, finalize_func):
  """Wraps `finalize_func`, which maps the final reducer state to outputs."""
  wrapper = dataset_ops.StructuredFunctionWrapper(
      finalize_func,
      self._transformation_name(),
      input_classes=self._state_classes,
      input_shapes=self._state_shapes,
      input_types=self._state_types)
  self._finalize_func = wrapper.function
  # The finalize function's outputs define the transformation's signature.
  self._output_types = wrapper.output_types
  self._output_shapes = wrapper.output_shapes
  self._output_classes = wrapper.output_classes
def _make_window_size_func(self, window_size_func):
  """Wraps `window_size_func`, coercing its result to a scalar int64."""
  scalar_int64 = structure.TensorStructure(dtypes.int64, [])

  def window_size_func_wrapper(key):
    # Allow `window_size_func` to return a plain Python int.
    return ops.convert_to_tensor(window_size_func(key), dtype=dtypes.int64)

  wrapped = dataset_ops.StructuredFunctionWrapper(
      window_size_func_wrapper, self._transformation_name(),
      input_structure=scalar_int64)
  self._window_size_func = wrapped
  if not wrapped.output_structure.is_compatible_with(scalar_int64):
    raise ValueError(
        "`window_size_func` must return a single tf.int64 scalar tensor.")
def __init__(self, input_dataset, map_func):
  """See `map_x_dataset()` for details."""
  super(_MapXDataset, self).__init__()
  self._input_dataset = input_dataset
  wrapped_func = dataset_ops.StructuredFunctionWrapper(
      map_func,
      "tf.contrib.data.map_x_dataset()",
      input_dataset,
      # Allow `map_func` to return nested datasets.
      experimental_nested_dataset_support=True)
  # The mapped function determines the output signature.
  self._output_classes = wrapped_func.output_classes
  self._output_shapes = wrapped_func.output_shapes
  self._output_types = wrapped_func.output_types
  self._map_func = wrapped_func.function
def _make_reduce_func(self, reduce_func, input_dataset):
  """Make wrapping defun for reduce_func."""
  # `reduce_func` receives a scalar int64 key and a dataset containing the
  # window's elements.
  nested_dataset = dataset_ops.DatasetSpec(input_dataset.element_spec)
  input_structure = (tensor_spec.TensorSpec([], dtypes.int64),
                     nested_dataset)
  self._reduce_func = dataset_ops.StructuredFunctionWrapper(
      reduce_func, self._transformation_name(),
      input_structure=input_structure)
  if not isinstance(self._reduce_func.output_structure,
                    dataset_ops.DatasetSpec):
    raise TypeError("`reduce_func` must return a `Dataset` object.")
  # The reduced dataset's element spec becomes this dataset's spec.
  # pylint: disable=protected-access
  self._element_spec = (self._reduce_func.output_structure._element_spec)
def _make_window_size_func(self, window_size_func):
  """Make wrapping Defun for window_size_func."""

  def window_size_func_wrapper(key):
    # Allow `window_size_func` to return a plain Python int.
    return ops.convert_to_tensor(window_size_func(key), dtype=dtypes.int64)

  wrapped_func = dataset_ops.StructuredFunctionWrapper(
      window_size_func_wrapper,
      "tf.contrib.data.group_by_window()",
      input_classes=ops.Tensor,
      input_shapes=tensor_shape.scalar(),
      input_types=dtypes.int64)
  # The window size must be a scalar int64 tensor.
  if not (
      wrapped_func.output_types == dtypes.int64 and
      wrapped_func.output_shapes.is_compatible_with(tensor_shape.scalar())):
    raise ValueError(
        "`window_size_func` must return a single tf.int64 scalar tensor.")
  self._window_size_func = wrapped_func.function
def __init__(self, input_dataset, map_func, batch_size, num_parallel_calls,
             drop_remainder, use_legacy_function=False):
  """See `tf.data.experimental.map_and_batch()` for details."""
  self._input_dataset = input_dataset
  self._map_func = dataset_ops.StructuredFunctionWrapper(
      map_func,
      "tf.data.experimental.map_and_batch()",
      dataset=input_dataset,
      use_legacy_function=use_legacy_function)
  self._batch_size_t = ops.convert_to_tensor(
      batch_size, dtype=dtypes.int64, name="batch_size")
  self._num_parallel_calls_t = ops.convert_to_tensor(
      num_parallel_calls, dtype=dtypes.int64, name="num_parallel_calls")
  self._drop_remainder_t = ops.convert_to_tensor(
      drop_remainder, dtype=dtypes.bool, name="drop_remainder")

  # If `drop_remainder` is statically True, the batch dimension is a known
  # constant; otherwise it must remain unknown (None).
  constant_drop_remainder = tensor_util.constant_value(
      self._drop_remainder_t)
  # pylint: disable=protected-access
  if constant_drop_remainder:
    # NOTE(mrry): `constant_drop_remainder` may be `None` (unknown statically)
    # or `False` (explicitly retaining the remainder).
    # pylint: disable=g-long-lambda
    self._element_spec = nest.map_structure(
        lambda component_spec: component_spec._batch(
            tensor_util.constant_value(self._batch_size_t)),
        self._map_func.output_structure)
  else:
    self._element_spec = nest.map_structure(
        lambda component_spec: component_spec._batch(None),
        self._map_func.output_structure)
  # pylint: enable=protected-access
  variant_tensor = ged_ops.map_and_batch_dataset(
      self._input_dataset._variant_tensor,  # pylint: disable=protected-access
      self._map_func.function.captured_inputs,
      f=self._map_func.function,
      batch_size=self._batch_size_t,
      num_parallel_calls=self._num_parallel_calls_t,
      drop_remainder=self._drop_remainder_t,
      preserve_cardinality=True,
      **self._flat_structure)
  super(_MapAndBatchDataset, self).__init__(input_dataset, variant_tensor)
def __init__(self, input_dataset, map_func):
  """See `Dataset.flat_map()` for details."""
  self._input_dataset = input_dataset
  self._map_func = dataset_ops.StructuredFunctionWrapper(
      map_func,
      self._transformation_name(),
      dataset=input_dataset,
      # Run the traced function on the single-threaded executor.
      defun_kwargs={"_executor": "SINGLE_THREADED_EXECUTOR"})
  # `map_func` must return a dataset; its element spec becomes this
  # dataset's structure.
  self._structure = self._map_func.output_structure._element_spec  # pylint: disable=protected-access
  variant_tensor = gen_dataset_ops.flat_map_dataset(
      input_dataset._variant_tensor,  # pylint: disable=protected-access
      self._map_func.function.captured_inputs,
      f=self._map_func.function,
      **self._flat_structure)
  super(SingleThreadedFlatMapDataset, self).__init__(input_dataset,
                                                     variant_tensor)