Example #1
    def __init__(self,
                 input_dataset,
                 path,
                 shard_func,
                 compression=None,
                 reader_func=None,
                 pending_snapshot_expiry_seconds=None,
                 use_legacy_function=False):

        if reader_func is None:
            reader_func = lambda datasets: datasets.interleave(  # pylint:disable=g-long-lambda
                lambda x: x,
                cycle_length=multiprocessing.cpu_count(),
                num_parallel_calls=dataset_ops.AUTOTUNE)

        self._input_dataset = input_dataset
        self._path = path
        self._compression = compression

        self._reader_func = dataset_ops.StructuredFunctionWrapper(
            reader_func,
            self._transformation_name() + ".reader_func",
            # Dataset of datasets of input elements
            input_structure=dataset_ops.DatasetSpec(
                dataset_ops.DatasetSpec(input_dataset.element_spec)),
            use_legacy_function=use_legacy_function)
        self._shard_func = dataset_ops.StructuredFunctionWrapper(
            shard_func,
            self._transformation_name() + ".shard_func",
            dataset=input_dataset,
            use_legacy_function=use_legacy_function)

        if ((not self._shard_func.output_structure.is_compatible_with(
                tensor_spec.TensorSpec([], dtypes.int32)))
                and (not self._shard_func.output_structure.is_compatible_with(
                    tensor_spec.TensorSpec([], dtypes.int64)))):
            raise TypeError(
                "shard_func must return a 0-dimension tensor containing an int."
            )

        variant_tensor = ged_ops.snapshot_dataset_v2(
            input_dataset._variant_tensor,  # pylint: disable=protected-access
            path,
            self._reader_func.function.captured_inputs,
            self._shard_func.function.captured_inputs,
            compression=compression,
            reader_func=self._reader_func.function,
            shard_func=self._shard_func.function,
            **self._flat_structure)
        super(_SnapshotDataset, self).__init__(input_dataset, variant_tensor)
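
As context for the op call above, here is a minimal usage sketch of the public `Dataset.snapshot` API that constructs this dataset; the path, `shard_func`, and `reader_func` below are illustrative assumptions, not values from the source.

import multiprocessing
import tensorflow as tf

ds = tf.data.Dataset.range(100)

# Persist the pipeline output to disk on the first run and re-read it on
# later runs. The directory and the two callbacks are only examples.
snapshotted = ds.snapshot(
    path="/tmp/snapshot_dir",
    compression="AUTO",
    # shard_func must return a scalar int32/int64 shard index per element,
    # which is the constraint checked in the constructor above.
    shard_func=lambda x: x % 4,
    # reader_func receives a dataset of per-shard datasets and flattens it,
    # mirroring the default defined above.
    reader_func=lambda datasets: datasets.interleave(
        lambda x: x,
        cycle_length=multiprocessing.cpu_count(),
        num_parallel_calls=tf.data.AUTOTUNE))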
Example #2
    def _make_reduce_func(self, reduce_func, input_dataset):
        """Make wrapping defun for reduce_func."""

        # Iteratively rerun the reduce function until reaching a fixed point on
        # `self._state_shapes`.
        need_to_rerun = True
        while need_to_rerun:

            wrapped_func = dataset_ops.StructuredFunctionWrapper(
                reduce_func,
                self._transformation_name(),
                input_classes=(self._state_classes,
                               input_dataset.output_classes),
                input_shapes=(self._state_shapes, input_dataset.output_shapes),
                input_types=(self._state_types, input_dataset.output_types),
                add_to_graph=False)

            # Extract and validate class information from the returned values.
            for new_state_class, state_class in zip(
                    nest.flatten(wrapped_func.output_classes),
                    nest.flatten(self._state_classes)):
                if not issubclass(new_state_class, state_class):
                    raise TypeError(
                        "The element classes for the new state must match the initial "
                        "state. Expected %s; got %s." %
                        (self._state_classes, wrapped_func.output_classes))

            # Extract and validate type information from the returned values.
            for new_state_type, state_type in zip(
                    nest.flatten(wrapped_func.output_types),
                    nest.flatten(self._state_types)):
                if new_state_type != state_type:
                    raise TypeError(
                        "The element types for the new state must match the initial "
                        "state. Expected %s; got %s." %
                        (self._state_types, wrapped_func.output_types))

            # Extract shape information from the returned values.
            flat_state_shapes = nest.flatten(self._state_shapes)
            flat_new_state_shapes = nest.flatten(wrapped_func.output_shapes)
            weakened_state_shapes = [
                original.most_specific_compatible_shape(new) for original, new
                in zip(flat_state_shapes, flat_new_state_shapes)
            ]

            need_to_rerun = False
            for original_shape, weakened_shape in zip(flat_state_shapes,
                                                      weakened_state_shapes):
                if original_shape.ndims is not None and (
                        weakened_shape.ndims is None or
                        original_shape.as_list() != weakened_shape.as_list()):
                    need_to_rerun = True
                    break

            if need_to_rerun:
                self._state_shapes = nest.pack_sequence_as(
                    self._state_shapes, weakened_state_shapes)

        self._reduce_func = wrapped_func.function
        self._reduce_func.add_to_graph(ops.get_default_graph())
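
To make the fixed-point step concrete, here is a small illustration (my own, not from the source) of the shape weakening that `most_specific_compatible_shape` performs; when the re-traced `reduce_func` produces a state shape that disagrees with the current one, the state shape is relaxed and the wrapping is rerun.

import tensorflow as tf

initial_state_shape = tf.TensorShape([3])     # fully known initial state
new_state_shape = tf.TensorShape([None])      # shape produced by reduce_func

# Dimensions that disagree are relaxed to None, e.g. [3] vs [None] -> [None].
weakened = initial_state_shape.most_specific_compatible_shape(new_state_shape)
print(weakened)  # (None,)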
Example #3
  def __init__(self, path, element_spec=None, compression=None,
               reader_func=None):

    if reader_func is None:
      reader_func = lambda datasets: datasets.interleave(  # pylint:disable=g-long-lambda
          lambda x: x,
          cycle_length=multiprocessing.cpu_count(),
          num_parallel_calls=dataset_ops.AUTOTUNE)

    self._path = path
    if element_spec is None:
      with gfile.GFile(os.path.join(path, DATASET_SPEC_FILENAME), "rb") as f:
        encoded_spec = f.read()
      struct_pb = nested_structure_coder.struct_pb2.StructuredValue()
      struct_pb.ParseFromString(encoded_spec)
      coder = nested_structure_coder.StructureCoder()
      spec = coder.decode_proto(struct_pb)
      self._element_spec = spec
    else:
      self._element_spec = element_spec
    self._compression = compression
    self._reader_func = dataset_ops.StructuredFunctionWrapper(
        reader_func,
        "load()",
        # Dataset of datasets of input elements
        input_structure=dataset_ops.DatasetSpec(
            dataset_ops.DatasetSpec(self._element_spec)))

    variant_tensor = gen_experimental_dataset_ops.load_dataset(
        path,
        reader_func_other_args=self._reader_func.function.captured_inputs,
        compression=compression,
        reader_func=self._reader_func.function,
        **self._flat_structure)
    super(_LoadDataset, self).__init__(variant_tensor)
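
For reference, a hedged sketch of the public loading entry point that builds this dataset; the path is an illustrative assumption and presumes a dataset was previously written with the matching save API.

import tensorflow as tf

# Illustrative directory; assumed to contain a dataset written earlier with
# tf.data.experimental.save(some_dataset, "/tmp/saved_ds").
loaded = tf.data.experimental.load("/tmp/saved_ds")

# When the on-disk spec file cannot be read, element_spec must be supplied
# explicitly, e.g.:
#   loaded = tf.data.experimental.load(
#       "/tmp/saved_ds", element_spec=tf.TensorSpec([], tf.int64))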
Example #4
  def __init__(self, input_dataset, map_func, batch_size, num_parallel_calls,
               drop_remainder, use_legacy_function=False):
    """See `Dataset.map()` for details."""
    self._input_dataset = input_dataset

    self._map_func = dataset_ops.StructuredFunctionWrapper(
        map_func,
        "tf.data.experimental.map_and_batch()",
        dataset=input_dataset,
        use_legacy_function=use_legacy_function)
    self._batch_size_t = ops.convert_to_tensor(
        batch_size, dtype=dtypes.int64, name="batch_size")
    self._num_parallel_calls_t = ops.convert_to_tensor(
        num_parallel_calls, dtype=dtypes.int64, name="num_parallel_calls")
    self._drop_remainder_t = ops.convert_to_tensor(
        drop_remainder, dtype=dtypes.bool, name="drop_remainder")

    constant_drop_remainder = tensor_util.constant_value(self._drop_remainder_t)
    if constant_drop_remainder:
      # NOTE(mrry): `constant_drop_remainder` may be `None` (unknown statically)
      # or `False` (explicitly retaining the remainder).
      self._structure = self._map_func.output_structure._batch(  # pylint: disable=protected-access
          tensor_util.constant_value(self._batch_size_t))
    else:
      self._structure = self._map_func.output_structure._batch(None)  # pylint: disable=protected-access
    variant_tensor = ged_ops.experimental_map_and_batch_dataset(
        self._input_dataset._variant_tensor,  # pylint: disable=protected-access
        self._map_func.function.captured_inputs,
        f=self._map_func.function,
        batch_size=self._batch_size_t,
        num_parallel_calls=self._num_parallel_calls_t,
        drop_remainder=self._drop_remainder_t,
        **dataset_ops.flat_structure(self))
    super(_MapAndBatchDataset, self).__init__(input_dataset, variant_tensor)
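
As a usage sketch (values chosen for illustration), the fused transformation above is reached through the experimental `map_and_batch` API, which behaves like `map` followed by `batch` but lowers to the single op shown:

import tensorflow as tf

ds = tf.data.Dataset.range(10)

# Roughly equivalent to ds.map(lambda x: x * 2).batch(4, drop_remainder=True),
# but fused into one map_and_batch_dataset op.
ds = ds.apply(
    tf.data.experimental.map_and_batch(
        map_func=lambda x: x * 2,
        batch_size=4,
        num_parallel_calls=tf.data.AUTOTUNE,
        drop_remainder=True))

for batch in ds:
    print(batch.numpy())  # [0 2 4 6], then [8 10 12 14]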
Example #5
 def __init__(self, input_dataset, map_func, cycle_length, block_length,
              sloppy, buffer_output_elements, prefetch_input_elements):
   """See `tf.data.experimental.parallel_interleave()` for details."""
   self._input_dataset = input_dataset
   self._map_func = dataset_ops.StructuredFunctionWrapper(
       map_func, self._transformation_name(), dataset=input_dataset)
   if not isinstance(self._map_func.output_structure, dataset_ops.DatasetSpec):
     raise TypeError("`map_func` must return a `Dataset` object.")
   self._element_spec = self._map_func.output_structure._element_spec  # pylint: disable=protected-access
   self._cycle_length = ops.convert_to_tensor(
       cycle_length, dtype=dtypes.int64, name="cycle_length")
   self._block_length = ops.convert_to_tensor(
       block_length, dtype=dtypes.int64, name="block_length")
   self._sloppy = ops.convert_to_tensor(
       sloppy, dtype=dtypes.bool, name="sloppy")
   self._buffer_output_elements = convert.optional_param_to_tensor(
       "buffer_output_elements",
       buffer_output_elements,
       argument_default=2 * block_length)
   self._prefetch_input_elements = convert.optional_param_to_tensor(
       "prefetch_input_elements",
       prefetch_input_elements,
       argument_default=2 * cycle_length)
   variant_tensor = ged_ops.parallel_interleave_dataset(
       self._input_dataset._variant_tensor,  # pylint: disable=protected-access
       self._map_func.function.captured_inputs,
       self._cycle_length,
       self._block_length,
       self._sloppy,
       self._buffer_output_elements,
       self._prefetch_input_elements,
       f=self._map_func.function,
       **self._flat_structure)
   super(ParallelInterleaveDataset, self).__init__(input_dataset,
                                                   variant_tensor)
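
A short, hedged usage sketch of the deprecated `tf.data.experimental.parallel_interleave` transformation built on this class; the file names are illustrative assumptions.

import tensorflow as tf

filenames = tf.data.Dataset.from_tensor_slices(
    ["/tmp/part-0.tfrecord", "/tmp/part-1.tfrecord"])  # illustrative paths

# Reads several files concurrently; sloppy=True trades deterministic output
# order for throughput.
records = filenames.apply(
    tf.data.experimental.parallel_interleave(
        lambda path: tf.data.TFRecordDataset(path),
        cycle_length=2,
        block_length=16,
        sloppy=True))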
Example #6
    def __init__(self, path, element_spec, compression=None, reader_func=None):

        if reader_func is None:
            reader_func = lambda datasets: datasets.interleave(  # pylint:disable=g-long-lambda
                lambda x: x,
                cycle_length=multiprocessing.cpu_count(),
                num_parallel_calls=dataset_ops.AUTOTUNE)

        self._path = path
        self._element_spec = element_spec
        self._compression = compression

        self._reader_func = dataset_ops.StructuredFunctionWrapper(
            reader_func,
            "load()",
            # Dataset of datasets of input elements
            input_structure=dataset_ops.DatasetSpec(
                dataset_ops.DatasetSpec(element_spec)))

        variant_tensor = gen_experimental_dataset_ops.load_dataset(
            path,
            reader_func_other_args=self._reader_func.function.captured_inputs,
            compression=compression,
            reader_func=self._reader_func.function,
            **self._flat_structure)
        super(_LoadDataset, self).__init__(variant_tensor)
Example #7
    def __init__(self, input_dataset, predicate):
        """See `take_while()` for details."""

        self._input_dataset = input_dataset
        wrapped_func = dataset_ops.StructuredFunctionWrapper(
            predicate,
            "tf.data.experimental.take_while()",
            dataset=self._input_dataset)

        if not wrapped_func.output_structure.is_compatible_with(
                structure_lib.TensorStructure(dtypes.bool, [])):
            raise ValueError(
                "`predicate` must return a scalar boolean tensor.")

        self._predicate = wrapped_func
        if compat.forward_compatible(2019, 8, 3):
            var_tensor = gen_experimental_dataset_ops.take_while_dataset(
                self._input_dataset._variant_tensor,  # pylint: disable=protected-access
                other_arguments=self._predicate.function.captured_inputs,
                predicate=self._predicate.function,
                **self._flat_structure)
        else:
            var_tensor = gen_experimental_dataset_ops.experimental_take_while_dataset(
                self._input_dataset._variant_tensor,  # pylint: disable=protected-access
                other_arguments=self._predicate.function.captured_inputs,
                predicate=self._predicate.function,
                **self._flat_structure)
        super(_TakeWhileDataset, self).__init__(input_dataset, var_tensor)
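
To show the scalar-boolean contract the check above enforces, a minimal usage sketch of `tf.data.experimental.take_while` (the predicate is an arbitrary example):

import tensorflow as tf

ds = tf.data.Dataset.range(10)

# Iteration stops at the first element for which the predicate is False.
ds = ds.apply(tf.data.experimental.take_while(lambda x: x < 5))

print(list(ds.as_numpy_iterator()))  # [0, 1, 2, 3, 4]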
Example #8
def _set_save_dataset_attributes(dataset, shard_func, path):
  """Sets parameters for SaveDatasetOp and SaveDatasetV2Op."""
  if shard_func is None:
    use_shard_func = False
    shard_func = lambda *x: None  # a dummy function that will not be used
  else:
    use_shard_func = True

  wrapped_func = dataset_ops.StructuredFunctionWrapper(
      shard_func,
      "save()",
      input_structure=dataset.element_spec,
      add_to_graph=False)

  coder = nested_structure_coder.StructureCoder()
  encoded = coder.encode_structure(dataset.element_spec)
  gfile.MakeDirs(path)
  with gfile.GFile(os.path.join(path, DATASET_SPEC_FILENAME), "wb") as f:
    f.write(encoded.SerializeToString())

  path = ops.convert_to_tensor(path, dtype=dtypes.string, name="path")
  shard_func = wrapped_func.function
  shard_func.add_to_graph(ops.get_default_graph())

  # pylint: disable=protected-access
  dataset._apply_debug_options()
  return dataset, shard_func, use_shard_func, path,
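
For context, a hedged sketch of the public save API these attributes feed into; the output directory and `shard_func` are illustrative assumptions.

import tensorflow as tf

ds = tf.data.Dataset.range(100)

# When shard_func is given, use_shard_func is set above and each element is
# routed to the shard index the function returns.
tf.data.experimental.save(
    ds,
    "/tmp/saved_ds",          # illustrative output directory
    compression="GZIP",
    shard_func=lambda x: x % 4)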
Example #9
    def __init__(self,
                 input_dataset,
                 map_func,
                 cycle_length,
                 block_length,
                 sloppy,
                 buffer_output_elements,
                 prefetch_input_elements,
                 name=None):
        """See `tf.data.experimental.parallel_interleave()` for details."""
        self._input_dataset = input_dataset
        self._map_func = dataset_ops.StructuredFunctionWrapper(
            map_func, self._transformation_name(), dataset=input_dataset)
        if not isinstance(self._map_func.output_structure,
                          dataset_ops.DatasetSpec):
            raise TypeError(
                "The `map_func` argument must return a `Dataset` object. Got "
                f"{_get_type(self._map_func.output_structure)!r}.")
        self._element_spec = self._map_func.output_structure._element_spec  # pylint: disable=protected-access
        self._cycle_length = ops.convert_to_tensor(cycle_length,
                                                   dtype=dtypes.int64,
                                                   name="cycle_length")
        self._block_length = ops.convert_to_tensor(block_length,
                                                   dtype=dtypes.int64,
                                                   name="block_length")
        self._buffer_output_elements = convert.optional_param_to_tensor(
            "buffer_output_elements",
            buffer_output_elements,
            argument_default=2 * block_length)
        self._prefetch_input_elements = convert.optional_param_to_tensor(
            "prefetch_input_elements",
            prefetch_input_elements,
            argument_default=2 * cycle_length)
        if sloppy is None:
            self._deterministic = "default"
        elif sloppy:
            self._deterministic = "false"
        else:
            self._deterministic = "true"
        self._metadata = dataset_metadata_pb2.Metadata()
        if name:
            self._metadata.name = dataset_ops._validate_and_encode(name)
        kwargs = self._flat_structure
        if name or compat.forward_compatible(2021, 9, 30):
            kwargs["metadata"] = self._metadata.SerializeToString()

        variant_tensor = ged_ops.legacy_parallel_interleave_dataset_v2(
            self._input_dataset._variant_tensor,  # pylint: disable=protected-access
            self._map_func.function.captured_inputs,
            self._cycle_length,
            self._block_length,
            self._buffer_output_elements,
            self._prefetch_input_elements,
            f=self._map_func.function,
            deterministic=self._deterministic,
            **kwargs)
        super(ParallelInterleaveDataset,
              self).__init__(input_dataset, variant_tensor)
Example #10
 def _make_key_func(self, key_func, input_dataset):
     """Make wrapping defun for key_func."""
     self._key_func = dataset_ops.StructuredFunctionWrapper(
         key_func, self._transformation_name(), dataset=input_dataset)
     if not self._key_func.output_structure.is_compatible_with(
             structure.TensorStructure(dtypes.int64, [])):
         raise ValueError(
             "`key_func` must return a single tf.int64 tensor. "
             "Got type=%s and shape=%s" %
             (self._key_func.output_types, self._key_func.output_shapes))
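
To make the `key_func` contract concrete, a minimal usage sketch of `tf.data.experimental.group_by_window` (parameters chosen for illustration); the key function must return a scalar tf.int64, which is exactly what the compatibility check above verifies.

import tensorflow as tf

ds = tf.data.Dataset.range(10)

# Groups elements by parity and emits each group in batches of up to 5.
ds = ds.apply(
    tf.data.experimental.group_by_window(
        key_func=lambda x: x % 2,                   # scalar tf.int64 key
        reduce_func=lambda key, window: window.batch(5),
        window_size=5))

for batch in ds:
    print(batch.numpy())  # e.g. [0 2 4 6 8] then [1 3 5 7 9]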
Example #11
 def _make_finalize_func(self, finalize_func):
   """Make wrapping Defun for finalize_func."""
   wrapped_func = dataset_ops.StructuredFunctionWrapper(
       finalize_func, "tf.contrib.data.group_by_reducer()",
       input_classes=self._state_classes, input_shapes=self._state_shapes,
       input_types=self._state_types)
   self._finalize_func = wrapped_func.function
   self._output_classes = wrapped_func.output_classes
   self._output_shapes = wrapped_func.output_shapes
   self._output_types = wrapped_func.output_types
Example #12
 def _make_init_func(self, init_func):
   """Make wrapping Defun for init_func."""
   wrapped_func = dataset_ops.StructuredFunctionWrapper(
       init_func, "tf.contrib.data.group_by_reducer()",
       input_classes=ops.Tensor, input_shapes=tensor_shape.scalar(),
       input_types=dtypes.int64)
   self._init_func = wrapped_func.function
   self._state_classes = wrapped_func.output_classes
   self._state_shapes = wrapped_func.output_shapes
   self._state_types = wrapped_func.output_types
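
For context, a hedged sketch of the `Reducer`-based API served by these `_make_*_func` helpers; the per-key sum is an arbitrary example.

import tensorflow as tf

ds = tf.data.Dataset.range(10)

# A Reducer bundles the init_func, reduce_func and finalize_func that the
# helpers in these examples wrap into defuns.
reducer = tf.data.experimental.Reducer(
    init_func=lambda key: tf.constant(0, dtype=tf.int64),
    reduce_func=lambda state, value: state + value,
    finalize_func=lambda state: state)

# Sums the elements within each parity group.
ds = ds.apply(
    tf.data.experimental.group_by_reducer(
        key_func=lambda x: x % 2, reducer=reducer))

print(list(ds.as_numpy_iterator()))  # e.g. [20, 25]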
Example #13
    def __init__(
        self,
        input_dataset,
        map_func,
        cycle_length,
        block_length,
        num_parallel_calls,
        deterministic,
        buffer_output_elements=None,  # backward compatibility with TF2.0
        prefetch_input_elements=None):  # backward compatibility with TF2.0
        """See `Dataset.interleave()` for details."""
        self._input_dataset = input_dataset
        self._map_func = ds.StructuredFunctionWrapper(
            map_func._func, self._transformation_name(), input_dataset)
        self._cycle_length = cycle_length
        self._block_length = block_length
        self._buffer_output_elements = buffer_output_elements
        self._prefetch_input_elements = prefetch_input_elements
        self._num_parallel_calls = num_parallel_calls
        self._deterministic = deterministic

        if (buffer_output_elements and buffer_output_elements != ds_helpers.autotune_flag()) or \
           (prefetch_input_elements and prefetch_input_elements != ds_helpers.autotune_flag()):
            variant_tensor = gen_dataset_ops.parallel_interleave_dataset_v4(
                input_dataset._variant_tensor,  # pylint: disable=protected-access
                self._map_func.function.captured_inputs,  # pylint: disable=protected-access
                self._cycle_length,
                self._block_length,
                self._buffer_output_elements,
                self._prefetch_input_elements,
                self._num_parallel_calls,
                f=self._map_func.function,
                deterministic=deterministic,
                **self._flat_structure)
        elif deterministic != "default":
            variant_tensor = gen_dataset_ops.parallel_interleave_dataset_v3(
                input_dataset._variant_tensor,  # pylint: disable=protected-access
                self._map_func.function.captured_inputs,  # pylint: disable=protected-access
                self._cycle_length,
                self._block_length,
                self._num_parallel_calls,
                f=self._map_func.function,
                deterministic=deterministic,
                **self._flat_structure)
        else:
            variant_tensor = gen_dataset_ops.parallel_interleave_dataset_v2(
                input_dataset._variant_tensor,  # pylint: disable=protected-access
                self._map_func.function.captured_inputs,  # pylint: disable=protected-access
                self._cycle_length,
                self._block_length,
                self._num_parallel_calls,
                f=self._map_func.function,
                **self._flat_structure)
        super(TntParallelInterleaveDataset,
              self).__init__(input_dataset, variant_tensor)
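
The three op variants dispatched above back the standard `Dataset.interleave` API; a minimal, non-Tarantella-specific sketch with illustrative parameters:

import tensorflow as tf

ds = tf.data.Dataset.range(4)

# Supplying num_parallel_calls selects the parallel interleave op variants;
# deterministic=False relaxes output ordering for better throughput.
ds = ds.interleave(
    lambda x: tf.data.Dataset.range(x + 1),
    cycle_length=2,
    block_length=1,
    num_parallel_calls=tf.data.AUTOTUNE,
    deterministic=False)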
Example #14
    def __init__(self, input_dataset, map_func):
        """See `map_x_dataset()` for details."""
        super(_MapXDataset, self).__init__(input_dataset)
        self._input_dataset = input_dataset

        wrapped_func = dataset_ops.StructuredFunctionWrapper(
            map_func, self._transformation_name(), dataset=input_dataset)
        self._output_classes = wrapped_func.output_classes
        self._output_shapes = wrapped_func.output_shapes
        self._output_types = wrapped_func.output_types
        self._map_func = wrapped_func.function
Example #15
  def _make_key_func(self, key_func, input_dataset):
    """Make wrapping defun for key_func."""

    def key_func_wrapper(*args):
      return ops.convert_to_tensor(key_func(*args), dtype=dtypes.int64)
    self._key_func = dataset_ops.StructuredFunctionWrapper(
        key_func_wrapper, self._transformation_name(), dataset=input_dataset)
    if not self._key_func.output_structure.is_compatible_with(
        structure.TensorStructure(dtypes.int64, [])):
      raise ValueError(
          "`key_func` must return a single tf.int64 scalar tensor.")
Example #16
 def __init__(self, input_dataset, predicate, use_legacy_function=False):
     """See `Dataset.filter()` for details."""
     self._input_dataset = input_dataset
     self._predicate = ds.StructuredFunctionWrapper(
         predicate._func, self._transformation_name(), input_dataset)
     variant_tensor = gen_dataset_ops.filter_dataset(
         input_dataset._variant_tensor,  # pylint: disable=protected-access
         other_arguments=self._predicate.function.captured_inputs,
         predicate=self._predicate.function,
         **self._flat_structure)
     super(TntFilterDataset, self).__init__(input_dataset, variant_tensor)
Example #17
  def __init__(self, input_dataset, map_func, use_inter_op_parallelism=True):
    """See `Dataset.map()` for details."""
    super(_MapOnGpuDataset, self).__init__(input_dataset)
    self._input_dataset = input_dataset
    self._use_inter_op_parallelism = use_inter_op_parallelism

    self._map_func = dataset_ops.StructuredFunctionWrapper(
        map_func,
        self._transformation_name(),
        dataset=input_dataset,
        defun_kwargs={"experimental_ints_on_device": True})
Example #18
 def __init__(self, input_dataset, map_func):
     """See `Dataset.flat_map()` for details."""
     self._input_dataset = input_dataset
     self._map_func = ds.StructuredFunctionWrapper(
         map_func._func, self._transformation_name(), input_dataset)
     variant_tensor = gen_dataset_ops.flat_map_dataset(
         self._input_dataset._variant_tensor,  # pylint: disable=protected-access
         self._map_func.function.captured_inputs,
         f=self._map_func.function,
         **self._flat_structure)
     super(TntFlatMapDataset, self).__init__(input_dataset, variant_tensor)
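
A minimal sketch of the standard `Dataset.flat_map` usage that this wrapper mirrors; the mapping function is an arbitrary example.

import tensorflow as tf

ds = tf.data.Dataset.from_tensor_slices([[1, 2, 3], [4, 5, 6]])

# map_func must return a Dataset; flat_map concatenates the results in order.
ds = ds.flat_map(lambda row: tf.data.Dataset.from_tensor_slices(row))

print(list(ds.as_numpy_iterator()))  # [1, 2, 3, 4, 5, 6]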
Example #19
 def _make_init_func(self, init_func):
     """Make wrapping defun for init_func."""
     wrapped_func = dataset_ops.StructuredFunctionWrapper(
         init_func,
         self._transformation_name(),
         input_classes=ops.Tensor,
         input_shapes=tensor_shape.scalar(),
         input_types=dtypes.int64)
     self._init_func = wrapped_func.function
     self._state_classes = wrapped_func.output_classes
     self._state_shapes = wrapped_func.output_shapes
     self._state_types = wrapped_func.output_types
Example #20
 def _make_key_func(self, key_func, input_dataset):
     """Make wrapping defun for key_func."""
     wrapped_func = dataset_ops.StructuredFunctionWrapper(
         key_func, self._transformation_name(), dataset=input_dataset)
     if not (wrapped_func.output_types == dtypes.int64
             and wrapped_func.output_shapes.is_compatible_with(
                 tensor_shape.scalar())):
         raise ValueError(
             "`key_func` must return a single tf.int64 tensor. "
             "Got type=%s and shape=%s" %
             (wrapped_func.output_types, wrapped_func.output_shapes))
     self._key_func = wrapped_func.function
Example #21
 def _make_key_func(self, key_func, input_dataset):
   """Make wrapping Defun for key_func."""
   def key_func_wrapper(*args):
     return ops.convert_to_tensor(key_func(*args), dtype=dtypes.int64)
   wrapped_func = dataset_ops.StructuredFunctionWrapper(
       key_func_wrapper, "tf.contrib.data.group_by_window()", input_dataset)
   if not (
       wrapped_func.output_types == dtypes.int64 and
       wrapped_func.output_shapes.is_compatible_with(tensor_shape.scalar())):
     raise ValueError(
         "`key_func` must return a single tf.int64 scalar tensor.")
   self._key_func = wrapped_func.function
Example #22
 def _make_key_func(self, key_func, input_dataset):
   """Make wrapping defun for key_func."""
   self._key_func = dataset_ops.StructuredFunctionWrapper(
       key_func, self._transformation_name(), dataset=input_dataset)
   if not self._key_func.output_structure.is_compatible_with(
       tensor_spec.TensorSpec([], dtypes.int64)):
     raise ValueError(
         f"Invalid `key_func`. Expected `key_func` to return a scalar "
         f"tf.int64 tensor, but instead `key_func` has output "
         f"types={self._key_func.output_types} "
         f"and shapes={self._key_func.output_shapes}."
     )
Example #23
 def _make_key_func(self, key_func, input_dataset):
   """Make wrapping Defun for key_func."""
   wrapped_func = dataset_ops.StructuredFunctionWrapper(
       key_func, "tf.data.experimental.group_by_reducer()", input_dataset)
   if not (
       wrapped_func.output_types == dtypes.int64 and
       wrapped_func.output_shapes.is_compatible_with(tensor_shape.scalar())):
     raise ValueError(
         "`key_func` must return a single tf.int64 tensor. "
         "Got type=%s and shape=%s"
         % (wrapped_func.output_types, wrapped_func.output_shapes))
   self._key_func = wrapped_func.function
Example #24
 def _make_finalize_func(self, finalize_func):
     """Make wrapping defun for finalize_func."""
     wrapped_func = dataset_ops.StructuredFunctionWrapper(
         finalize_func,
         self._transformation_name(),
         input_classes=self._state_classes,
         input_shapes=self._state_shapes,
         input_types=self._state_types)
     self._finalize_func = wrapped_func.function
     self._output_classes = wrapped_func.output_classes
     self._output_shapes = wrapped_func.output_shapes
     self._output_types = wrapped_func.output_types
Example #25
  def _make_window_size_func(self, window_size_func):
    """Make wrapping defun for window_size_func."""

    def window_size_func_wrapper(key):
      return ops.convert_to_tensor(window_size_func(key), dtype=dtypes.int64)
    self._window_size_func = dataset_ops.StructuredFunctionWrapper(
        window_size_func_wrapper,
        self._transformation_name(),
        input_structure=structure.TensorStructure(dtypes.int64, []))
    if not self._window_size_func.output_structure.is_compatible_with(
        structure.TensorStructure(dtypes.int64, [])):
      raise ValueError(
          "`window_size_func` must return a single tf.int64 scalar tensor.")
Example #26
    def __init__(self, input_dataset, map_func):
        """See `map_x_dataset()` for details."""
        super(_MapXDataset, self).__init__()
        self._input_dataset = input_dataset

        wrapped_func = dataset_ops.StructuredFunctionWrapper(
            map_func,
            "tf.contrib.data.map_x_dataset()",
            input_dataset,
            experimental_nested_dataset_support=True)
        self._output_classes = wrapped_func.output_classes
        self._output_shapes = wrapped_func.output_shapes
        self._output_types = wrapped_func.output_types
        self._map_func = wrapped_func.function
Example #27
 def _make_reduce_func(self, reduce_func, input_dataset):
     """Make wrapping defun for reduce_func."""
     nested_dataset = dataset_ops.DatasetSpec(input_dataset.element_spec)
     input_structure = (tensor_spec.TensorSpec([], dtypes.int64),
                        nested_dataset)
     self._reduce_func = dataset_ops.StructuredFunctionWrapper(
         reduce_func,
         self._transformation_name(),
         input_structure=input_structure)
     if not isinstance(self._reduce_func.output_structure,
                       dataset_ops.DatasetSpec):
         raise TypeError("`reduce_func` must return a `Dataset` object.")
     # pylint: disable=protected-access
     self._element_spec = (self._reduce_func.output_structure._element_spec)
Example #28
 def _make_window_size_func(self, window_size_func):
   """Make wrapping Defun for window_size_func."""
   def window_size_func_wrapper(key):
     return ops.convert_to_tensor(window_size_func(key), dtype=dtypes.int64)
   wrapped_func = dataset_ops.StructuredFunctionWrapper(
       window_size_func_wrapper, "tf.contrib.data.group_by_window()",
       input_classes=ops.Tensor, input_shapes=tensor_shape.scalar(),
       input_types=dtypes.int64)
   if not (
       wrapped_func.output_types == dtypes.int64 and
       wrapped_func.output_shapes.is_compatible_with(tensor_shape.scalar())):
     raise ValueError(
         "`window_size_func` must return a single tf.int64 scalar tensor.")
   self._window_size_func = wrapped_func.function
Example #29
    def __init__(self,
                 input_dataset,
                 map_func,
                 batch_size,
                 num_parallel_calls,
                 drop_remainder,
                 use_legacy_function=False):
        self._input_dataset = input_dataset

        self._map_func = dataset_ops.StructuredFunctionWrapper(
            map_func,
            "tf.data.experimental.map_and_batch()",
            dataset=input_dataset,
            use_legacy_function=use_legacy_function)
        self._batch_size_t = ops.convert_to_tensor(batch_size,
                                                   dtype=dtypes.int64,
                                                   name="batch_size")
        self._num_parallel_calls_t = ops.convert_to_tensor(
            num_parallel_calls, dtype=dtypes.int64, name="num_parallel_calls")
        self._drop_remainder_t = ops.convert_to_tensor(drop_remainder,
                                                       dtype=dtypes.bool,
                                                       name="drop_remainder")

        constant_drop_remainder = tensor_util.constant_value(
            self._drop_remainder_t)
        # pylint: disable=protected-access
        if constant_drop_remainder:
            # NOTE(mrry): `constant_drop_remainder` may be `None` (unknown statically)
            # or `False` (explicitly retaining the remainder).
            # pylint: disable=g-long-lambda
            self._element_spec = nest.map_structure(
                lambda component_spec: component_spec._batch(
                    tensor_util.constant_value(self._batch_size_t)),
                self._map_func.output_structure)
        else:
            self._element_spec = nest.map_structure(
                lambda component_spec: component_spec._batch(None),
                self._map_func.output_structure)
        # pylint: enable=protected-access
        variant_tensor = ged_ops.map_and_batch_dataset(
            self._input_dataset._variant_tensor,  # pylint: disable=protected-access
            self._map_func.function.captured_inputs,
            f=self._map_func.function,
            batch_size=self._batch_size_t,
            num_parallel_calls=self._num_parallel_calls_t,
            drop_remainder=self._drop_remainder_t,
            preserve_cardinality=True,
            **self._flat_structure)
        super(_MapAndBatchDataset, self).__init__(input_dataset,
                                                  variant_tensor)
Example #30
 def __init__(self, input_dataset, map_func):
     """See `Dataset.flat_map()` for details."""
     self._input_dataset = input_dataset
     self._map_func = dataset_ops.StructuredFunctionWrapper(
         map_func,
         self._transformation_name(),
         dataset=input_dataset,
         defun_kwargs={"_executor": "SINGLE_THREADED_EXECUTOR"})
     self._structure = self._map_func.output_structure._element_spec  # pylint: disable=protected-access
     variant_tensor = gen_dataset_ops.flat_map_dataset(
         input_dataset._variant_tensor,  # pylint: disable=protected-access
         self._map_func.function.captured_inputs,
         f=self._map_func.function,
         **self._flat_structure)
     super(SingleThreadedFlatMapDataset,
           self).__init__(input_dataset, variant_tensor)