def __init__(self, input_dataset, map_func, batch_size, num_parallel_calls,
             drop_remainder):
  """Builds the fused map-and-batch op over `input_dataset`.

  See `Dataset.map()` for details on `map_func`.

  Args:
    input_dataset: The dataset whose variant tensor feeds the op.
    map_func: Function wrapped in a `StructuredFunctionWrapper` and handed to
      the op as `f`.
    batch_size: Value converted to an int64 tensor named "batch_size".
    num_parallel_calls: Value converted to an int64 tensor named
      "num_parallel_calls".
    drop_remainder: Value converted to a bool tensor named "drop_remainder".
  """
  # pylint: disable=protected-access
  self._input_dataset = input_dataset
  self._map_func = dataset_ops.StructuredFunctionWrapper(
      map_func,
      "tf.data.experimental.map_and_batch()",
      dataset=input_dataset)

  self._batch_size_t = ops.convert_to_tensor(
      batch_size, dtype=dtypes.int64, name="batch_size")
  self._num_parallel_calls_t = ops.convert_to_tensor(
      num_parallel_calls, dtype=dtypes.int64, name="num_parallel_calls")
  self._drop_remainder_t = ops.convert_to_tensor(
      drop_remainder, dtype=dtypes.bool, name="drop_remainder")

  static_drop_remainder = tensor_util.constant_value(self._drop_remainder_t)
  if static_drop_remainder:
    # The remainder is statically known to be dropped, so the static value of
    # the batch size (if any) is used as the batch dimension.
    leading_dim = tensor_util.constant_value(self._batch_size_t)
  else:
    # NOTE(mrry): a falsy `static_drop_remainder` is either `None` (unknown
    # statically) or `False` (explicitly retaining the remainder); in both
    # cases no static batch dimension is passed.
    leading_dim = None
  self._structure = self._map_func.output_structure._batch(leading_dim)

  wrapped_fn = self._map_func.function
  variant_tensor = ged_ops.experimental_map_and_batch_dataset(
      self._input_dataset._variant_tensor,
      wrapped_fn.captured_inputs,
      f=wrapped_fn,
      batch_size=self._batch_size_t,
      num_parallel_calls=self._num_parallel_calls_t,
      drop_remainder=self._drop_remainder_t,
      preserve_cardinality=True,
      **dataset_ops.flat_structure(self))
  # pylint: enable=protected-access
  super(_MapAndBatchDataset, self).__init__(input_dataset, variant_tensor)
def __init__(self, input_dataset, map_func, batch_size, num_parallel_calls,
             drop_remainder, use_legacy_function=False):
  """Builds the fused map-and-batch op over `input_dataset`.

  See `Dataset.map()` for details on `map_func`.

  Args:
    input_dataset: The dataset whose variant tensor feeds the op.
    map_func: Function wrapped in a `StructuredFunctionWrapper` and handed to
      the op as `f`.
    batch_size: Value converted to an int64 tensor named "batch_size".
    num_parallel_calls: Value converted to an int64 tensor named
      "num_parallel_calls".
    drop_remainder: Value converted to a bool tensor named "drop_remainder".
    use_legacy_function: Forwarded unchanged to `StructuredFunctionWrapper`.
      Defaults to `False`.
  """
  # pylint: disable=protected-access
  self._input_dataset = input_dataset
  self._map_func = dataset_ops.StructuredFunctionWrapper(
      map_func,
      "tf.data.experimental.map_and_batch()",
      dataset=input_dataset,
      use_legacy_function=use_legacy_function)

  self._batch_size_t = ops.convert_to_tensor(
      batch_size, dtype=dtypes.int64, name="batch_size")
  self._num_parallel_calls_t = ops.convert_to_tensor(
      num_parallel_calls, dtype=dtypes.int64, name="num_parallel_calls")
  self._drop_remainder_t = ops.convert_to_tensor(
      drop_remainder, dtype=dtypes.bool, name="drop_remainder")

  static_drop_remainder = tensor_util.constant_value(self._drop_remainder_t)
  if static_drop_remainder:
    # The remainder is statically known to be dropped, so the static value of
    # the batch size (if any) is used as the batch dimension.
    leading_dim = tensor_util.constant_value(self._batch_size_t)
  else:
    # NOTE(mrry): a falsy `static_drop_remainder` is either `None` (unknown
    # statically) or `False` (explicitly retaining the remainder); in both
    # cases no static batch dimension is passed.
    leading_dim = None
  self._structure = self._map_func.output_structure._batch(leading_dim)

  wrapped_fn = self._map_func.function
  variant_tensor = ged_ops.experimental_map_and_batch_dataset(
      self._input_dataset._variant_tensor,
      wrapped_fn.captured_inputs,
      f=wrapped_fn,
      batch_size=self._batch_size_t,
      num_parallel_calls=self._num_parallel_calls_t,
      drop_remainder=self._drop_remainder_t,
      preserve_cardinality=True,
      **dataset_ops.flat_structure(self))
  # pylint: enable=protected-access
  super(_MapAndBatchDataset, self).__init__(input_dataset, variant_tensor)
def _as_variant_tensor(self):
  """Returns the result of the map-and-batch op applied to the input dataset.

  The op receives the input dataset's variant tensor, the wrapped map
  function with its captured inputs, the batch size, parallelism and
  drop-remainder tensors, and the flattened form of `self._output_structure`.
  """
  # pylint: disable=protected-access
  upstream_variant = self._input_dataset._as_variant_tensor()
  wrapped_fn = self._map_func.function
  structure_kwargs = dataset_ops.flat_structure(
      structure=self._output_structure)
  # pylint: enable=protected-access
  return ged_ops.experimental_map_and_batch_dataset(
      upstream_variant,
      wrapped_fn.captured_inputs,
      f=wrapped_fn,
      batch_size=self._batch_size_t,
      num_parallel_calls=self._num_parallel_calls_t,
      drop_remainder=self._drop_remainder_t,
      **structure_kwargs)
def _as_variant_tensor(self):
  """Returns the result of the map-and-batch op applied to the input dataset.

  The op receives the input dataset's variant tensor, the wrapped map
  function with its captured inputs, the batch size, parallelism and
  drop-remainder tensors, and the flattened form of `self._output_structure`.
  """
  # pylint: disable=protected-access
  upstream_variant = self._input_dataset._as_variant_tensor()
  wrapped_fn = self._map_func.function
  structure_kwargs = dataset_ops.flat_structure(
      structure=self._output_structure)
  # pylint: enable=protected-access
  return ged_ops.experimental_map_and_batch_dataset(
      upstream_variant,
      wrapped_fn.captured_inputs,
      f=wrapped_fn,
      batch_size=self._batch_size_t,
      num_parallel_calls=self._num_parallel_calls_t,
      drop_remainder=self._drop_remainder_t,
      **structure_kwargs)
def __init__(self, input_dataset, map_func, batch_size, num_parallel_calls,
             drop_remainder, use_legacy_function=False):
  """Builds the fused map-and-batch op over `input_dataset`.

  See `Dataset.map()` for details on `map_func`.

  Args:
    input_dataset: The dataset whose variant tensor feeds the op.
    map_func: Function wrapped in a `StructuredFunctionWrapper` and handed to
      the op as `f`.
    batch_size: Value converted to an int64 tensor named "batch_size".
    num_parallel_calls: Value converted to an int64 tensor named
      "num_parallel_calls".
    drop_remainder: Value converted to a bool tensor named "drop_remainder".
    use_legacy_function: Forwarded unchanged to `StructuredFunctionWrapper`.
      Defaults to `False`.
  """
  # pylint: disable=protected-access
  self._input_dataset = input_dataset
  self._map_func = dataset_ops.StructuredFunctionWrapper(
      map_func,
      "tf.data.experimental.map_and_batch()",
      dataset=input_dataset,
      use_legacy_function=use_legacy_function)

  self._batch_size_t = ops.convert_to_tensor(
      batch_size, dtype=dtypes.int64, name="batch_size")
  self._num_parallel_calls_t = ops.convert_to_tensor(
      num_parallel_calls, dtype=dtypes.int64, name="num_parallel_calls")
  self._drop_remainder_t = ops.convert_to_tensor(
      drop_remainder, dtype=dtypes.bool, name="drop_remainder")

  static_drop_remainder = tensor_util.constant_value(self._drop_remainder_t)
  if static_drop_remainder:
    # The remainder is statically known to be dropped, so each component is
    # batched with the static value of the batch size (if any). The lookup is
    # performed per component, when the component is visited.
    def _batch_component(component_spec):
      return component_spec._batch(
          tensor_util.constant_value(self._batch_size_t))
  else:
    # NOTE(mrry): a falsy `static_drop_remainder` is either `None` (unknown
    # statically) or `False` (explicitly retaining the remainder); in both
    # cases no static batch dimension is passed.
    def _batch_component(component_spec):
      return component_spec._batch(None)

  self._element_spec = nest.map_structure(
      _batch_component, self._map_func.output_structure)

  # Choose the op endpoint based on the forward-compatibility window, then
  # invoke it once with the arguments both endpoints share.
  if compat.forward_compatible(2019, 8, 3):
    map_and_batch_op = ged_ops.map_and_batch_dataset
  else:
    map_and_batch_op = ged_ops.experimental_map_and_batch_dataset

  wrapped_fn = self._map_func.function
  variant_tensor = map_and_batch_op(
      self._input_dataset._variant_tensor,
      wrapped_fn.captured_inputs,
      f=wrapped_fn,
      batch_size=self._batch_size_t,
      num_parallel_calls=self._num_parallel_calls_t,
      drop_remainder=self._drop_remainder_t,
      preserve_cardinality=True,
      **self._flat_structure)
  # pylint: enable=protected-access
  super(_MapAndBatchDataset, self).__init__(input_dataset, variant_tensor)