# Example #1 (score: 0)
 def _as_variant_tensor(self):
   """Creates the variant tensor for this dataset via the experimental op."""
   # pylint: disable=protected-access
   op_inputs = (
       self._input_dataset._as_variant_tensor(),
       self._map_func.function.captured_inputs,
       self._cycle_length,
       self._block_length,
       self._sloppy,
       self._buffer_output_elements,
       self._prefetch_input_elements,
   )
   structure_kwargs = dataset_ops.flat_structure(
       structure=self._output_structure)
   return ged_ops.experimental_parallel_interleave_dataset(
       *op_inputs, f=self._map_func.function, **structure_kwargs)
# Example #2 (score: 0)
 def _as_variant_tensor(self):
     """Returns the variant tensor that backs this dataset."""
     # pylint: disable=protected-access
     input_variant = self._input_dataset._as_variant_tensor()
     map_fn = self._map_func.function
     flat_kwargs = dataset_ops.flat_structure(structure=self._output_structure)
     return ged_ops.experimental_parallel_interleave_dataset(
         input_variant,
         map_fn.captured_inputs,
         self._cycle_length,
         self._block_length,
         self._sloppy,
         self._buffer_output_elements,
         self._prefetch_input_elements,
         f=map_fn,
         **flat_kwargs)
# Example #3 (score: 0)
 def __init__(self, input_dataset, map_func, cycle_length, block_length,
              sloppy, buffer_output_elements, prefetch_input_elements):
     """See `tf.data.experimental.parallel_interleave()` for details."""
     self._input_dataset = input_dataset
     self._map_func = dataset_ops.StructuredFunctionWrapper(
         map_func, self._transformation_name(), dataset=input_dataset)
     # The interleave transformation requires `map_func` to produce datasets.
     if not isinstance(self._map_func.output_structure,
                       dataset_ops.DatasetSpec):
         raise TypeError("`map_func` must return a `Dataset` object.")
     self._element_spec = self._map_func.output_structure._element_spec  # pylint: disable=protected-access
     self._cycle_length = ops.convert_to_tensor(
         cycle_length, dtype=dtypes.int64, name="cycle_length")
     self._block_length = ops.convert_to_tensor(
         block_length, dtype=dtypes.int64, name="block_length")
     self._sloppy = ops.convert_to_tensor(
         sloppy, dtype=dtypes.bool, name="sloppy")
     # Unset optional knobs default to twice the corresponding static size.
     self._buffer_output_elements = convert.optional_param_to_tensor(
         "buffer_output_elements",
         buffer_output_elements,
         argument_default=2 * block_length)
     self._prefetch_input_elements = convert.optional_param_to_tensor(
         "prefetch_input_elements",
         prefetch_input_elements,
         argument_default=2 * cycle_length)
     # Both kernels take identical arguments; only the registered op name
     # differs, so pick the op first and make the call once.
     if compat.forward_compatible(2019, 8, 3):
         interleave_op = ged_ops.parallel_interleave_dataset
     else:
         interleave_op = ged_ops.experimental_parallel_interleave_dataset
     variant_tensor = interleave_op(
         self._input_dataset._variant_tensor,  # pylint: disable=protected-access
         self._map_func.function.captured_inputs,
         self._cycle_length,
         self._block_length,
         self._sloppy,
         self._buffer_output_elements,
         self._prefetch_input_elements,
         f=self._map_func.function,
         **self._flat_structure)
     super(ParallelInterleaveDataset,
           self).__init__(input_dataset, variant_tensor)
# Example #4 (score: 0)
 def __init__(self, input_dataset, map_func, cycle_length, block_length,
              sloppy, buffer_output_elements, prefetch_input_elements):
   """See `tf.data.experimental.parallel_interleave()` for details."""
   self._input_dataset = input_dataset
   self._map_func = dataset_ops.StructuredFunctionWrapper(
       map_func, self._transformation_name(), dataset=input_dataset)
   # Reject map functions that do not yield nested datasets.
   if not isinstance(self._map_func.output_structure,
                     dataset_ops.DatasetStructure):
     raise TypeError("`map_func` must return a `Dataset` object.")
   self._structure = self._map_func.output_structure._element_structure  # pylint: disable=protected-access
   self._cycle_length = ops.convert_to_tensor(cycle_length,
                                              dtype=dtypes.int64,
                                              name="cycle_length")
   self._block_length = ops.convert_to_tensor(block_length,
                                              dtype=dtypes.int64,
                                              name="block_length")
   self._sloppy = ops.convert_to_tensor(sloppy,
                                        dtype=dtypes.bool,
                                        name="sloppy")
   # Optional parameters default to twice the static block/cycle sizes.
   self._buffer_output_elements = convert.optional_param_to_tensor(
       "buffer_output_elements", buffer_output_elements,
       argument_default=2 * block_length)
   self._prefetch_input_elements = convert.optional_param_to_tensor(
       "prefetch_input_elements", prefetch_input_elements,
       argument_default=2 * cycle_length)
   map_fn = self._map_func.function
   variant_tensor = ged_ops.experimental_parallel_interleave_dataset(
       self._input_dataset._variant_tensor,  # pylint: disable=protected-access
       map_fn.captured_inputs,
       self._cycle_length,
       self._block_length,
       self._sloppy,
       self._buffer_output_elements,
       self._prefetch_input_elements,
       f=map_fn,
       **dataset_ops.flat_structure(self))
   super(ParallelInterleaveDataset, self).__init__(input_dataset,
                                                   variant_tensor)