Example 1: `_TFRecordDataset.__init__` (TensorFlow `tf.data` reader)
    def __init__(self,
                 filenames,
                 compression_type=None,
                 buffer_size=None,
                 name=None):
        """Creates a `TFRecordDataset`.

    Args:
      filenames: A `tf.string` tensor containing one or more filenames.
      compression_type: (Optional.) A `tf.string` scalar evaluating to one of
        `""` (no compression), `"ZLIB"`, or `"GZIP"`.
      buffer_size: (Optional.) A `tf.int64` scalar representing the number of
        bytes in the read buffer. 0 means no buffering.
      name: (Optional.) A name for the tf.data operation.
    """
        self._filenames = filenames
        self._compression_type = convert.optional_param_to_tensor(
            "compression_type",
            compression_type,
            argument_default="",
            argument_dtype=dtypes.string)
        self._buffer_size = convert.optional_param_to_tensor(
            "buffer_size",
            buffer_size,
            argument_default=_DEFAULT_READER_BUFFER_SIZE_BYTES)
        self._metadata = dataset_metadata_pb2.Metadata()
        if name:
            self._metadata.name = dataset_ops._validate_and_encode(name)

        variant_tensor = gen_dataset_ops.tf_record_dataset(
            self._filenames,
            self._compression_type,
            self._buffer_size,
            metadata=self._metadata.SerializeToString())
        super(_TFRecordDataset, self).__init__(variant_tensor)
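
The constructor above is the internal implementation behind the public `tf.data.TFRecordDataset` class, which accepts the same `filenames`, `compression_type`, and `buffer_size` arguments. A minimal usage sketch (the file names below are placeholders, not files that ship with TensorFlow):

import tensorflow as tf

# Placeholder paths; any existing TFRecord files would work here.
dataset = tf.data.TFRecordDataset(
    ["data/train-00000.tfrecord", "data/train-00001.tfrecord"],
    compression_type="GZIP",      # one of "", "ZLIB", "GZIP" as documented above
    buffer_size=8 * 1024 * 1024)  # 8 MiB read buffer

for raw_record in dataset.take(2):
    print(raw_record.numpy()[:20])  # each element is one serialized record (tf.string)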
Example 2: `ParallelInterleaveDataset.__init__` (TensorFlow `tf.data` interleave op)
    def __init__(self,
                 input_dataset,
                 map_func,
                 cycle_length,
                 block_length,
                 sloppy,
                 buffer_output_elements,
                 prefetch_input_elements,
                 name=None):
        """See `tf.data.experimental.parallel_interleave()` for details."""
        self._input_dataset = input_dataset
        self._map_func = structured_function.StructuredFunctionWrapper(
            map_func, self._transformation_name(), dataset=input_dataset)
        if not isinstance(self._map_func.output_structure,
                          dataset_ops.DatasetSpec):
            raise TypeError(
                "The `map_func` argument must return a `Dataset` object. Got "
                f"{_get_type(self._map_func.output_structure)!r}.")
        self._element_spec = self._map_func.output_structure._element_spec  # pylint: disable=protected-access
        self._cycle_length = ops.convert_to_tensor(cycle_length,
                                                   dtype=dtypes.int64,
                                                   name="cycle_length")
        self._block_length = ops.convert_to_tensor(block_length,
                                                   dtype=dtypes.int64,
                                                   name="block_length")
        self._buffer_output_elements = convert.optional_param_to_tensor(
            "buffer_output_elements",
            buffer_output_elements,
            argument_default=2 * block_length)
        self._prefetch_input_elements = convert.optional_param_to_tensor(
            "prefetch_input_elements",
            prefetch_input_elements,
            argument_default=2 * cycle_length)
        if sloppy is None:
            self._deterministic = "default"
        elif sloppy:
            self._deterministic = "false"
        else:
            self._deterministic = "true"
        self._metadata = dataset_metadata_pb2.Metadata()
        if name:
            self._metadata.name = dataset_ops._validate_and_encode(name)
        kwargs = self._flat_structure
        if name or compat.forward_compatible(2021, 9, 30):
            kwargs["metadata"] = self._metadata.SerializeToString()

        variant_tensor = ged_ops.legacy_parallel_interleave_dataset_v2(
            self._input_dataset._variant_tensor,  # pylint: disable=protected-access
            self._map_func.function.captured_inputs,
            self._cycle_length,
            self._block_length,
            self._buffer_output_elements,
            self._prefetch_input_elements,
            f=self._map_func.function,
            deterministic=self._deterministic,
            **kwargs)
        super(ParallelInterleaveDataset,
              self).__init__(input_dataset, variant_tensor)
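
`ParallelInterleaveDataset` is the implementation behind the deprecated `tf.data.experimental.parallel_interleave()` transformation; note how the `sloppy` flag is translated into the `deterministic` attribute of the generated op. A usage sketch with placeholder file names, followed by the roughly equivalent modern `Dataset.interleave` form:

import tensorflow as tf

files = tf.data.Dataset.from_tensor_slices(
    ["logs/a.tfrecord", "logs/b.tfrecord"])  # placeholder paths

# Deprecated API backed by the constructor above; sloppy=True corresponds to
# deterministic="false" in the underlying op.
dataset = files.apply(
    tf.data.experimental.parallel_interleave(
        tf.data.TFRecordDataset,
        cycle_length=2,
        block_length=1,
        sloppy=True))

# Roughly equivalent modern form using Dataset.interleave:
dataset = files.interleave(
    tf.data.TFRecordDataset,
    cycle_length=2,
    block_length=1,
    num_parallel_calls=tf.data.AUTOTUNE,
    deterministic=False)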
Example 3: `_FixedLengthRecordDataset.__init__` (TensorFlow `tf.data` reader)
    def __init__(self,
                 filenames,
                 record_bytes,
                 header_bytes=None,
                 footer_bytes=None,
                 buffer_size=None,
                 compression_type=None,
                 name=None):
        """Creates a `FixedLengthRecordDataset`.

    Args:
      filenames: A `tf.string` tensor containing one or more filenames.
      record_bytes: A `tf.int64` scalar representing the number of bytes in each
        record.
      header_bytes: (Optional.) A `tf.int64` scalar representing the number of
        bytes to skip at the start of a file.
      footer_bytes: (Optional.) A `tf.int64` scalar representing the number of
        bytes to ignore at the end of a file.
      buffer_size: (Optional.) A `tf.int64` scalar representing the number of
        bytes to buffer when reading.
      compression_type: (Optional.) A `tf.string` scalar evaluating to one of
        `""` (no compression), `"ZLIB"`, or `"GZIP"`.
      name: (Optional.) A name for the tf.data operation.
    """
        self._filenames = filenames
        self._record_bytes = ops.convert_to_tensor(record_bytes,
                                                   dtype=dtypes.int64,
                                                   name="record_bytes")
        self._header_bytes = convert.optional_param_to_tensor(
            "header_bytes", header_bytes)
        self._footer_bytes = convert.optional_param_to_tensor(
            "footer_bytes", footer_bytes)
        self._buffer_size = convert.optional_param_to_tensor(
            "buffer_size", buffer_size, _DEFAULT_READER_BUFFER_SIZE_BYTES)
        self._compression_type = convert.optional_param_to_tensor(
            "compression_type",
            compression_type,
            argument_default="",
            argument_dtype=dtypes.string)
        self._metadata = dataset_metadata_pb2.Metadata()
        if name:
            self._metadata.name = dataset_ops._validate_and_encode(name)

        variant_tensor = gen_dataset_ops.fixed_length_record_dataset_v2(
            self._filenames,
            self._header_bytes,
            self._record_bytes,
            self._footer_bytes,
            self._buffer_size,
            self._compression_type,
            metadata=self._metadata.SerializeToString())
        super(_FixedLengthRecordDataset, self).__init__(variant_tensor)
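
This constructor backs the public `tf.data.FixedLengthRecordDataset`, which reads fixed-size binary records. A sketch for CIFAR-10-style binary files, assuming each record is one label byte followed by a 3x32x32 uint8 image (the file name is a placeholder):

import tensorflow as tf

RECORD_BYTES = 1 + 3 * 32 * 32  # label byte + image bytes per record

dataset = tf.data.FixedLengthRecordDataset(
    ["cifar/data_batch_1.bin"],  # placeholder path
    record_bytes=RECORD_BYTES,
    header_bytes=0,
    footer_bytes=0)

def parse(record):
    # Each element is a raw byte string of exactly RECORD_BYTES bytes.
    raw = tf.io.decode_raw(record, tf.uint8)
    label = tf.cast(raw[0], tf.int32)
    image = tf.reshape(raw[1:], [3, 32, 32])
    return image, label

dataset = dataset.map(parse)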
Example 4: `_TextLineDataset.__init__` (TensorFlow `tf.data` reader)
    def __init__(self,
                 filenames,
                 compression_type=None,
                 buffer_size=None,
                 name=None):
        """Creates a `TextLineDataset`.

    Args:
      filenames: A `tf.string` tensor containing one or more filenames.
      compression_type: (Optional.) A `tf.string` scalar evaluating to one of
        `""` (no compression), `"ZLIB"`, or `"GZIP"`.
      buffer_size: (Optional.) A `tf.int64` scalar denoting the number of bytes
        to buffer. A value of 0 results in the default buffering values chosen
        based on the compression type.
      name: (Optional.) A name for the tf.data operation.
    """
        self._filenames = filenames
        self._compression_type = convert.optional_param_to_tensor(
            "compression_type",
            compression_type,
            argument_default="",
            argument_dtype=dtypes.string)
        self._buffer_size = convert.optional_param_to_tensor(
            "buffer_size",
            buffer_size,
            argument_default=_DEFAULT_READER_BUFFER_SIZE_BYTES)
        self._metadata = dataset_metadata_pb2.Metadata()
        if name:
            self._metadata.name = dataset_ops._validate_and_encode(name)
        kwargs = {}
        if name or compat.forward_compatible(2021, 9, 30):
            kwargs["metadata"] = self._metadata.SerializeToString()

        variant_tensor = gen_dataset_ops.text_line_dataset(
            self._filenames, self._compression_type, self._buffer_size,
            **kwargs)
        super(_TextLineDataset, self).__init__(variant_tensor)
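
This constructor backs the public `tf.data.TextLineDataset`, which yields one `tf.string` element per line of the input files. A minimal usage sketch with placeholder file names:

import tensorflow as tf

dataset = tf.data.TextLineDataset(
    ["corpus/part-0.txt.gz", "corpus/part-1.txt.gz"],  # placeholder paths
    compression_type="GZIP",
    buffer_size=1024 * 1024)

# Drop empty lines and inspect a few elements.
dataset = dataset.filter(lambda line: tf.strings.length(line) > 0)
for line in dataset.take(3):
    print(line.numpy())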