Example #1
def _set_save_dataset_attributes(dataset, shard_func, path):
    """Sets parameters for SaveDatasetOp and SaveDatasetV2Op."""
    if shard_func is None:
        use_shard_func = False
        shard_func = lambda *x: None  # a dummy function that will not be used
    else:
        use_shard_func = True

    wrapped_func = structured_function.StructuredFunctionWrapper(
        shard_func,
        "save()",
        input_structure=dataset.element_spec,
        add_to_graph=False)

    encoded = nested_structure_coder.encode_structure(dataset.element_spec)
    gfile.MakeDirs(path)
    with gfile.GFile(os.path.join(path, DATASET_SPEC_FILENAME), "wb") as f:
        f.write(encoded.SerializeToString())

    path = ops.convert_to_tensor(path, dtype=dtypes.string, name="path")
    shard_func = wrapped_func.function
    shard_func.add_to_graph(ops.get_default_graph())

    # pylint: disable=protected-access
    dataset._apply_debug_options()
    return dataset, shard_func, use_shard_func, path
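
A quick usage sketch (assuming TF >= 2.3; the path is hypothetical): the public tf.data.experimental.save API is what invokes _set_save_dataset_attributes with the user-supplied shard_func.

import tensorflow as tf

# Minimal sketch: save a dataset with a custom shard_func, the callable that
# _set_save_dataset_attributes wraps above. The path is illustrative.
dataset = tf.data.Dataset.range(100)

# shard_func must map each input element to an int64 shard index.
tf.data.experimental.save(dataset, "/tmp/my_dataset",
                          shard_func=lambda x: x % 4)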
Example #2
    def __init__(self,
                 input_dataset,
                 map_func,
                 cycle_length,
                 block_length,
                 sloppy,
                 buffer_output_elements,
                 prefetch_input_elements,
                 name=None):
        """See `tf.data.experimental.parallel_interleave()` for details."""
        self._input_dataset = input_dataset
        self._map_func = structured_function.StructuredFunctionWrapper(
            map_func, self._transformation_name(), dataset=input_dataset)
        if not isinstance(self._map_func.output_structure,
                          dataset_ops.DatasetSpec):
            raise TypeError(
                "The `map_func` argument must return a `Dataset` object. Got "
                f"{_get_type(self._map_func.output_structure)!r}.")
        self._element_spec = self._map_func.output_structure._element_spec  # pylint: disable=protected-access
        self._cycle_length = ops.convert_to_tensor(cycle_length,
                                                   dtype=dtypes.int64,
                                                   name="cycle_length")
        self._block_length = ops.convert_to_tensor(block_length,
                                                   dtype=dtypes.int64,
                                                   name="block_length")
        self._buffer_output_elements = convert.optional_param_to_tensor(
            "buffer_output_elements",
            buffer_output_elements,
            argument_default=2 * block_length)
        self._prefetch_input_elements = convert.optional_param_to_tensor(
            "prefetch_input_elements",
            prefetch_input_elements,
            argument_default=2 * cycle_length)
        if sloppy is None:
            self._deterministic = "default"
        elif sloppy:
            self._deterministic = "false"
        else:
            self._deterministic = "true"
        self._metadata = dataset_metadata_pb2.Metadata()
        if name:
            self._metadata.name = dataset_ops._validate_and_encode(name)
        kwargs = self._flat_structure
        if name or compat.forward_compatible(2021, 9, 30):
            kwargs["metadata"] = self._metadata.SerializeToString()

        variant_tensor = ged_ops.legacy_parallel_interleave_dataset_v2(
            self._input_dataset._variant_tensor,  # pylint: disable=protected-access
            self._map_func.function.captured_inputs,
            self._cycle_length,
            self._block_length,
            self._buffer_output_elements,
            self._prefetch_input_elements,
            f=self._map_func.function,
            deterministic=self._deterministic,
            **kwargs)
        super(ParallelInterleaveDataset,
              self).__init__(input_dataset, variant_tensor)
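
A quick usage sketch (the file names are hypothetical): the deprecated tf.data.experimental.parallel_interleave transformation constructs the ParallelInterleaveDataset above via Dataset.apply.

import tensorflow as tf

# Minimal sketch of the public entry point for this class.
filenames = tf.data.Dataset.from_tensor_slices(
    ["file_0.tfrecord", "file_1.tfrecord"])  # hypothetical files
dataset = filenames.apply(
    tf.data.experimental.parallel_interleave(
        lambda f: tf.data.TFRecordDataset(f),
        cycle_length=2,
        block_length=1,
        sloppy=True))  # sloppy=True maps to self._deterministic = "false"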
Example #3
    def _make_key_func(self, key_func, input_dataset):
        """Make wrapping defun for key_func."""
        self._key_func = structured_function.StructuredFunctionWrapper(
            key_func, self._transformation_name(), dataset=input_dataset)
        if not self._key_func.output_structure.is_compatible_with(
                tensor_spec.TensorSpec([], dtypes.int64)):
            raise ValueError(
                f"Invalid `key_func`. Expected `key_func` to return a scalar "
                f"tf.int64 tensor, but instead `key_func` has output "
                f"types={self._key_func.output_types} "
                f"and shapes={self._key_func.output_shapes}.")
Example #4
    def __init__(self,
                 input_dataset,
                 map_func,
                 batch_size,
                 num_parallel_calls,
                 drop_remainder,
                 use_legacy_function=False):
        self._input_dataset = input_dataset

        self._map_func = structured_function.StructuredFunctionWrapper(
            map_func,
            "tf.data.experimental.map_and_batch()",
            dataset=input_dataset,
            use_legacy_function=use_legacy_function)
        self._batch_size_t = ops.convert_to_tensor(batch_size,
                                                   dtype=dtypes.int64,
                                                   name="batch_size")
        self._num_parallel_calls_t = ops.convert_to_tensor(
            num_parallel_calls, dtype=dtypes.int64, name="num_parallel_calls")
        self._drop_remainder_t = ops.convert_to_tensor(drop_remainder,
                                                       dtype=dtypes.bool,
                                                       name="drop_remainder")

        constant_drop_remainder = tensor_util.constant_value(
            self._drop_remainder_t)
        # pylint: disable=protected-access
        # NOTE(mrry): `constant_drop_remainder` may be `None` (unknown
        # statically) or `False` (explicitly retaining the remainder), in which
        # case the batch dimension below is left unknown.
        if constant_drop_remainder:
            # pylint: disable=g-long-lambda
            self._element_spec = nest.map_structure(
                lambda component_spec: component_spec._batch(
                    tensor_util.constant_value(self._batch_size_t)),
                self._map_func.output_structure)
        else:
            self._element_spec = nest.map_structure(
                lambda component_spec: component_spec._batch(None),
                self._map_func.output_structure)
        # pylint: enable=protected-access
        variant_tensor = ged_ops.map_and_batch_dataset(
            self._input_dataset._variant_tensor,  # pylint: disable=protected-access
            self._map_func.function.captured_inputs,
            f=self._map_func.function,
            batch_size=self._batch_size_t,
            num_parallel_calls=self._num_parallel_calls_t,
            drop_remainder=self._drop_remainder_t,
            preserve_cardinality=True,
            **self._flat_structure)
        super(_MapAndBatchDataset, self).__init__(input_dataset,
                                                  variant_tensor)
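
A quick usage sketch: the deprecated tf.data.experimental.map_and_batch transformation builds the _MapAndBatchDataset above, fusing the map and batch steps into a single op.

import tensorflow as tf

# Minimal sketch. With drop_remainder=True, constant_drop_remainder is
# statically True, so the element spec gets a known batch dimension.
dataset = tf.data.Dataset.range(100)
dataset = dataset.apply(
    tf.data.experimental.map_and_batch(
        lambda x: x * 2,
        batch_size=10,
        num_parallel_calls=tf.data.AUTOTUNE,
        drop_remainder=True))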
Example #5
    def __init__(self, input_dataset, map_func):
        """See `Dataset.flat_map()` for details."""
        self._input_dataset = input_dataset
        self._map_func = structured_function.StructuredFunctionWrapper(
            map_func,
            self._transformation_name(),
            dataset=input_dataset,
            defun_kwargs={"_executor": "SINGLE_THREADED_EXECUTOR"})
        self._structure = self._map_func.output_structure._element_spec  # pylint: disable=protected-access
        variant_tensor = gen_dataset_ops.flat_map_dataset(
            input_dataset._variant_tensor,  # pylint: disable=protected-access
            self._map_func.function.captured_inputs,
            f=self._map_func.function,
            **self._flat_structure)
        super(SingleThreadedFlatMapDataset,
              self).__init__(input_dataset, variant_tensor)
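
A quick usage sketch: SingleThreadedFlatMapDataset is an internal variant, but the user-facing Dataset.flat_map enforces the same contract, namely that map_func must return a Dataset.

import tensorflow as tf

# Minimal sketch of the flat_map contract this class relies on.
dataset = tf.data.Dataset.from_tensor_slices([[1, 2, 3], [4, 5, 6]])
dataset = dataset.flat_map(tf.data.Dataset.from_tensor_slices)
# Yields the scalars 1, 2, 3, 4, 5, 6.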
Example #6
    def __init__(self, input_dataset, map_func, use_inter_op_parallelism=True):
        """See `Dataset.map()` for details."""
        self._input_dataset = input_dataset
        self._use_inter_op_parallelism = use_inter_op_parallelism

        self._map_func = structured_function.StructuredFunctionWrapper(
            map_func,
            self._transformation_name(),
            dataset=input_dataset,
            defun_kwargs={"experimental_ints_on_device": True})
        variant_tensor = ged_ops.experimental_map_dataset(
            self._input_dataset._variant_tensor,  # pylint: disable=protected-access
            self._map_func.function.captured_inputs,
            f=self._map_func.function,
            use_inter_op_parallelism=self._use_inter_op_parallelism,
            **self._flat_structure)
        super(_MapOnGpuDataset, self).__init__(input_dataset, variant_tensor)
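
A quick usage sketch: _MapOnGpuDataset backs a non-public map-on-GPU helper; the user-facing equivalent of its map_func contract is Dataset.map.

import tensorflow as tf

# Minimal sketch: the same map_func contract via the public API.
dataset = tf.data.Dataset.range(10)
dataset = dataset.map(lambda x: x + 1)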
Example #7
    def __init__(self,
                 path,
                 element_spec=None,
                 compression=None,
                 reader_func=None):

        if reader_func is None:
            reader_func = lambda datasets: datasets.interleave(  # pylint:disable=g-long-lambda
                lambda x: x,
                cycle_length=multiprocessing.cpu_count(),
                num_parallel_calls=dataset_ops.AUTOTUNE)

        self._path = path
        if element_spec is None:
            if not context.executing_eagerly():
                raise ValueError(
                    "In graph mode the `element_spec` argument must be provided."
                )
            with gfile.GFile(os.path.join(path, DATASET_SPEC_FILENAME),
                             "rb") as f:
                encoded_spec = f.read()
            struct_pb = nested_structure_coder.struct_pb2.StructuredValue()
            struct_pb.ParseFromString(encoded_spec)
            coder = nested_structure_coder.StructureCoder()
            spec = coder.decode_proto(struct_pb)
            self._element_spec = spec
        else:
            self._element_spec = element_spec
        self._compression = compression
        self._reader_func = structured_function.StructuredFunctionWrapper(
            reader_func,
            "load()",
            # Dataset of datasets of input elements
            input_structure=dataset_ops.DatasetSpec(
                dataset_ops.DatasetSpec(self._element_spec)))

        variant_tensor = gen_experimental_dataset_ops.load_dataset(
            path,
            reader_func_other_args=self._reader_func.function.captured_inputs,
            compression=compression,
            reader_func=self._reader_func.function,
            **self._flat_structure)
        super(_LoadDataset, self).__init__(variant_tensor)
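
A quick usage sketch (assuming TF >= 2.4; the path is illustrative and matches the save() sketch in Example #1): tf.data.experimental.load constructs this _LoadDataset, and the default reader_func above interleaves the per-shard datasets.

import tensorflow as tf

# Minimal sketch: load a previously saved dataset with a custom reader_func
# that flattens the dataset of per-shard datasets. In graph mode the
# element_spec argument must also be supplied, as the code above enforces.
dataset = tf.data.experimental.load(
    "/tmp/my_dataset",
    reader_func=lambda datasets: datasets.interleave(
        lambda x: x, num_parallel_calls=tf.data.AUTOTUNE))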
Example #8
    def _make_finalize_func(self, finalize_func):
        """Make wrapping defun for finalize_func."""
        self._finalize_func = structured_function.StructuredFunctionWrapper(
            finalize_func,
            self._transformation_name(),
            input_structure=self._state_structure)
Example #9
    def _make_reduce_func(self, reduce_func, input_dataset):
        """Make wrapping defun for reduce_func."""

        # Iteratively rerun the reduce function until reaching a fixed point on
        # `self._state_structure`.
        self._state_structure = self._init_func.output_structure
        state_types = self._init_func.output_types
        state_shapes = self._init_func.output_shapes
        state_classes = self._init_func.output_classes
        need_to_rerun = True
        while need_to_rerun:

            wrapped_func = structured_function.StructuredFunctionWrapper(
                reduce_func,
                self._transformation_name(),
                input_structure=(self._state_structure,
                                 input_dataset.element_spec),
                add_to_graph=False)

            # Extract and validate class information from the returned values.
            for new_state_class, state_class in zip(
                    nest.flatten(wrapped_func.output_classes),
                    nest.flatten(state_classes)):
                if not issubclass(new_state_class, state_class):
                    raise TypeError(
                        f"Invalid `reducer`. The output classes of "
                        f"`reducer.reduce_func` {wrapped_func.output_classes} "
                        f"do not match the classes of the reduce state "
                        f"{state_classes}.")

            # Extract and validate type information from the returned values.
            for new_state_type, state_type in zip(
                    nest.flatten(wrapped_func.output_types),
                    nest.flatten(state_types)):
                if new_state_type != state_type:
                    raise TypeError(
                        f"Invalid `reducer`. The element types for the new state "
                        f"{wrapped_func.output_types} do not match the element types "
                        f"of the old state {self._init_func.output_types}.")

            # Extract shape information from the returned values.
            flat_state_shapes = nest.flatten(state_shapes)
            flat_new_state_shapes = nest.flatten(wrapped_func.output_shapes)
            weakened_state_shapes = [
                original.most_specific_compatible_shape(new) for original, new
                in zip(flat_state_shapes, flat_new_state_shapes)
            ]

            need_to_rerun = False
            for original_shape, weakened_shape in zip(flat_state_shapes,
                                                      weakened_state_shapes):
                if original_shape.ndims is not None and (
                        weakened_shape.ndims is None or
                        original_shape.as_list() != weakened_shape.as_list()):
                    need_to_rerun = True
                    break

            if need_to_rerun:
                state_shapes = nest.pack_sequence_as(
                    self._init_func.output_shapes, weakened_state_shapes)
                self._state_structure = structure.convert_legacy_structure(
                    state_types, state_shapes, state_classes)

        self._reduce_func = wrapped_func
        self._reduce_func.function.add_to_graph(ops.get_default_graph())
Example #10
    def _make_init_func(self, init_func):
        """Make wrapping defun for init_func."""
        self._init_func = structured_function.StructuredFunctionWrapper(
            init_func,
            self._transformation_name(),
            input_structure=tensor_spec.TensorSpec([], dtypes.int64))
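
A quick usage sketch covering Examples #8-#10 together: tf.data.experimental.Reducer bundles the init_func, reduce_func, and finalize_func that _make_init_func, _make_reduce_func, and _make_finalize_func wrap, here computing a per-key sum with group_by_reducer.

import tensorflow as tf

# Minimal sketch: sum elements grouped by parity.
reducer = tf.data.experimental.Reducer(
    init_func=lambda key: tf.constant(0, dtype=tf.int64),  # key is a scalar int64
    reduce_func=lambda state, value: state + value,        # fold one element
    finalize_func=lambda state: state)                     # emit the final state

dataset = tf.data.Dataset.range(10)
dataset = dataset.apply(
    tf.data.experimental.group_by_reducer(
        key_func=lambda x: x % 2,  # scalar tf.int64 key, as Example #3 requires
        reducer=reducer))
# Yields 20 (sum of evens) and 25 (sum of odds).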