Example #1
def features_encode_decode(features_dict, example, decoders):
    """Runs the full pipeline: encode > write > tmp files > read > decode."""
    # Encode example
    encoded_example = features_dict.encode_example(example)

    # Serialize/deserialize the example
    specs = features_dict.get_serialized_info()
    serializer = example_serializer.ExampleSerializer(specs)
    parser = example_parser.ExampleParser(specs)

    serialized_example = serializer.serialize_example(encoded_example)
    ds = tf.data.Dataset.from_tensors(serialized_example)
    ds = ds.map(parser.parse_example)

    # Decode the example
    decode_fn = functools.partial(
        features_dict.decode_example,
        decoders=decoders,
    )
    ds = ds.map(decode_fn)

    if tf.executing_eagerly():
        out_tensor = next(iter(ds))
    else:
        out_tensor = tf.compat.v1.data.make_one_shot_iterator(ds).get_next()
    out_numpy = dataset_utils.as_numpy(out_tensor)
    return out_tensor, out_numpy
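
A minimal usage sketch of the helper above (the feature spec and example are hypothetical; decoders=None keeps the default decoding for every feature):

import tensorflow_datasets as tfds

features = tfds.features.FeaturesDict({
    'label': tfds.features.ClassLabel(names=['cat', 'dog']),
})
# Encodes, serializes, parses and decodes a single example end to end.
out_tensor, out_numpy = features_encode_decode(
    features, {'label': 'dog'}, decoders=None)
print(out_numpy['label'])  # -> 1, the index of 'dog' in the ClassLabel names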
Example #2
    def __init__(self,
                 example_specs,
                 path,
                 hash_salt,
                 disable_shuffling,
                 file_format=file_adapters.DEFAULT_FILE_FORMAT):
        """Init BeamWriter.

    Args:
      example_specs: spec used to build the ExampleSerializer.
      path: str, path where to write the tfrecord files. Eg:
        "/foo/mnist-train.tfrecord". The shard suffix (eg: `.00000-of-00004`)
        is added by the BeamWriter. Note that a "{path}.split_info.json" file
        is also created; it contains the split info, including the number of
        examples in each final shard.
      hash_salt: string, the salt to use for hashing of keys.
      disable_shuffling: bool, if True, records are written in generation
        order instead of being shuffled.
      file_format: file_adapters.FileFormat, format of the record files in
        which the dataset will be written.
    """
        self._original_state = dict(example_specs=example_specs,
                                    path=path,
                                    hash_salt=hash_salt,
                                    disable_shuffling=disable_shuffling,
                                    file_format=file_format)
        self._path = os.fspath(path)
        self._split_info_path = "%s.split_info.json" % path
        self._serializer = example_serializer.ExampleSerializer(example_specs)
        self._example_specs = example_specs
        self._hasher = hashing.Hasher(hash_salt)
        self._split_info = None
        self._file_format = file_format
        self._disable_shuffling = disable_shuffling
Example #3
    def __init__(
        self,
        example_specs,
        filename_template: naming.ShardedFileTemplate,
        hash_salt,
        disable_shuffling: bool,
        file_format=file_adapters.DEFAULT_FILE_FORMAT,
    ):
        """Initializes Writer.

    Args:
      example_specs: spec to build ExampleSerializer.
      filename_template: template to format sharded filenames.
      hash_salt (str or bytes): salt to hash keys.
      disable_shuffling (bool): if True, records are written in generation
        order instead of being shuffled.
      file_format (FileFormat): format of the record files in which the
        dataset will be written.
    """
        self._example_specs = example_specs
        self._serializer = example_serializer.ExampleSerializer(example_specs)
        self._shuffler = shuffle.Shuffler(dirpath=filename_template.data_dir,
                                          hash_salt=hash_salt,
                                          disable_shuffling=disable_shuffling)
        self._num_examples = 0
        self._filename_template = filename_template
        self._file_format = file_format
Example #4
    def __init__(
        self,
        example_specs,
        path,
        hash_salt,
        disable_shuffling: bool,
        file_format=file_adapters.DEFAULT_FILE_FORMAT,
    ):
        """Initializes Writer.

    Args:
      example_specs: spec to build ExampleSerializer.
      path (str): path where the records should be written. Eg:
        "/foo/mnist-train.tfrecord".
      hash_salt (str or bytes): salt to hash keys.
      disable_shuffling (bool): if True, records are written in generation
        order instead of being shuffled.
      file_format (FileFormat): format of the record files in which the
        dataset will be written.
    """
        self._example_specs = example_specs
        self._serializer = example_serializer.ExampleSerializer(example_specs)
        self._shuffler = shuffle.Shuffler(os.path.dirname(path), hash_salt,
                                          disable_shuffling)
        self._num_examples = 0
        self._path = path
        self._file_format = file_format
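
A usage sketch for this path-based Writer, mirroring the write/finalize loop of Example #9 (the feature spec and output path are hypothetical, and the exact write() signature may vary between TFDS versions):

import tensorflow as tf
import tensorflow_datasets as tfds

features = tfds.features.FeaturesDict({
    'value': tfds.features.Tensor(shape=(), dtype=tf.int64),
})
writer = Writer(
    example_specs=features.get_serialized_info(),
    path='/tmp/my_dataset-train.tfrecord',  # hypothetical output path
    hash_salt='train',
    disable_shuffling=False,
)
for key, raw_example in enumerate([{'value': 1}, {'value': 2}]):
    # Examples are encoded with the FeaturesDict before being written.
    writer.write(key, features.encode_example(raw_example))
# finalize() flushes the shuffler into the sharded record files and returns
# the number of examples per shard plus the total size in bytes (see Example #9).
shard_lengths, total_size = writer.finalize()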
Example #5
    def __init__(self,
                 ds_info: dataset_info.DatasetInfo,
                 max_examples_per_shard: int,
                 overwrite: bool = True):
        """Creates a SequentialWriter.

    Args:
      ds_info: DatasetInfo for this dataset.
      max_examples_per_shard: maximum number of examples to write per shard.
      overwrite: if True, any existing data is ignored and overwritten.
        Otherwise, the existing dataset is loaded and the new data is appended
        to it (new data is always written as new shards).
    """

        self._data_dir = ds_info.data_dir
        self._ds_name = ds_info.name
        self._ds_info = ds_info
        if not overwrite:
            try:
                self._ds_info.read_from_directory(self._data_dir)
                # read_from_directory does some checks but not on the dataset name.
                if self._ds_info.name != self._ds_name:
                    raise ValueError(
                        f'Trying to append a dataset with name {self._ds_name}'
                        f' to an existing dataset with name {self._ds_info.name}'
                    )
            except FileNotFoundError:
                self._ds_info.set_file_format(
                    file_format=file_adapters.FileFormat.TFRECORD,
                    # if it was set, we want this to fail to warn the user
                    override=False)
        else:
            self._ds_info.set_file_format(
                file_format=file_adapters.FileFormat.TFRECORD,
                # if it was set, we want this to fail to warn the user
                override=False)

        self._filetype_suffix = ds_info.file_format.file_suffix
        self._max_examples_per_shard = max_examples_per_shard
        self._splits = {}
        if not overwrite:
            for split_name, split in ds_info.splits.items():
                self._splits[split_name] = _initialize_split(
                    split_name=split_name,
                    data_directory=self._data_dir,
                    ds_name=self._ds_name,
                    filetype_suffix=self._filetype_suffix,
                    shard_lengths=split.shard_lengths,
                    num_bytes=split.num_bytes)
        self._serializer = example_serializer.ExampleSerializer(
            self._ds_info.features.get_serialized_info())
Example #6
  def _build_from_pcollection(
      self,
      split_name: str,
      generator: 'beam.PCollection[KeyExample]',
      filename_template: naming.ShardedFileTemplate,
      disable_shuffling: bool,
  ) -> _SplitInfoFuture:
    """Split generator for `beam.PCollection`."""
    # TODO(tfds): Should try to add support to `max_examples_per_split`
    beam = lazy_imports_lib.lazy_imports.apache_beam

    beam_writer = writer_lib.BeamWriter(
        serializer=example_serializer.ExampleSerializer(
            self._features.get_serialized_info()),
        filename_template=filename_template,
        hash_salt=split_name,
        disable_shuffling=disable_shuffling,
        file_format=self._file_format,
    )

    def _encode_example(key_ex, encode_fn=self._features.encode_example):
      # We do not access self._features in this function to avoid pickling the
      # entire class.
      return key_ex[0], encode_fn(key_ex[1])

    # Note: We need to wrap the pipeline in a PTransform to avoid
    # errors due to duplicated `>> beam_labels`.
    @beam.ptransform_fn
    def _encode_pcollection(pipeline):
      """PTransformation which build a single split."""
      pcoll_examples = pipeline | 'Encode' >> beam.Map(_encode_example)
      return beam_writer.write_from_pcollection(pcoll_examples)

    # Add the PCollection to the pipeline
    _ = generator | f'{split_name}_write' >> _encode_pcollection()  # pylint: disable=no-value-for-parameter

    def _resolve_future():
      if self._in_contextmanager:
        raise AssertionError('`future.result()` should be called after the '
                             '`maybe_beam_pipeline` contextmanager.')
      logging.info('Retrieving split info for %s...', split_name)
      shard_lengths, total_size = beam_writer.finalize()
      return splits_lib.SplitInfo(
          name=split_name,
          shard_lengths=shard_lengths,
          num_bytes=total_size,
          filename_template=filename_template,
      )

    return _SplitInfoFuture(_resolve_future)
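
A sketch of driving this BeamWriter outside of the builder (the dataset pieces are hypothetical, and the naming.ShardedFileTemplate field names are assumptions that may differ between TFDS versions):

import apache_beam as beam
import tensorflow_datasets as tfds
from tensorflow_datasets.core import example_serializer, file_adapters, naming
from tensorflow_datasets.core import writer as writer_lib

features = tfds.features.FeaturesDict({'x': tfds.features.Text()})
# Hypothetical template; check naming.ShardedFileTemplate for the exact fields.
template = naming.ShardedFileTemplate(
    data_dir='/tmp/my_dataset',
    dataset_name='my_dataset',
    split='train',
    filetype_suffix='tfrecord',
)
beam_writer = writer_lib.BeamWriter(
    serializer=example_serializer.ExampleSerializer(
        features.get_serialized_info()),
    filename_template=template,
    hash_salt='train',
    disable_shuffling=False,
    file_format=file_adapters.DEFAULT_FILE_FORMAT,
)
with beam.Pipeline() as pipeline:
  # The writer expects a PCollection of (key, encoded_example) pairs.
  encoded = pipeline | beam.Create(
      [(i, features.encode_example({'x': s}))
       for i, s in enumerate(['a', 'b', 'c'])])
  _ = beam_writer.write_from_pcollection(encoded)
# As in _resolve_future above, finalize() must only run once the pipeline has
# completed; it returns the shard lengths and the total size in bytes.
shard_lengths, total_size = beam_writer.finalize()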
Example #7
    def __init__(self, example_specs, path, hash_salt):
        """Init BeamWriter.

    Args:
      example_specs: spec used to build the ExampleSerializer.
      path: str, path where to write the tfrecord files. Eg:
        "/foo/mnist-train.tfrecord". The shard suffix (eg: `.00000-of-00004`)
        is added by the BeamWriter. Note that a "{path}.shard_lengths.json"
        file is also created. It contains a list with the number of examples
        in each final shard. Eg: "[10,11,10,11]".
      hash_salt: string, the salt to use for hashing of keys.
    """
        self._original_state = dict(example_specs=example_specs,
                                    path=path,
                                    hash_salt=hash_salt)
        self._path = path
        self._shards_length_path = "%s.shard_lengths.json" % path
        self._serializer = example_serializer.ExampleSerializer(example_specs)
        self._hasher = hashing.Hasher(hash_salt)
        self._shard_lengths = None
Example #8
    def __init__(
        self,
        example_specs,
        filename_template: naming.ShardedFileTemplate,
        hash_salt,
        disable_shuffling: bool,
        file_format: file_adapters.FileFormat = file_adapters.DEFAULT_FILE_FORMAT,
    ):
        """Init BeamWriter.

    Note that a "{filepath_prefix}.split_info.json" file is also created; it
    contains the split info, including the number of examples in each final
    shard.

    Args:
      example_specs: spec used to build the ExampleSerializer.
      filename_template: template to format sharded filenames.
      hash_salt: string, the salt to use for hashing of keys.
      disable_shuffling: bool, if True, records are written in generation
        order instead of being shuffled.
      file_format: file_adapters.FileFormat, format of the record files in
        which the dataset will be written.
    """
        self._original_state = dict(example_specs=example_specs,
                                    filename_template=filename_template,
                                    hash_salt=hash_salt,
                                    disable_shuffling=disable_shuffling,
                                    file_format=file_format)
        self._filename_template = filename_template
        self._split_info_path = f"{filename_template.filepath_prefix()}.split_info.json"
        self._serializer = example_serializer.ExampleSerializer(example_specs)
        self._example_specs = example_specs
        self._hasher = hashing.Hasher(hash_salt)
        self._split_info = None
        self._file_format = file_format
        self._disable_shuffling = disable_shuffling
Example #9
  def _build_from_generator(
      self,
      split_name: str,
      generator: Iterable[KeyExample],
      filename_template: naming.ShardedFileTemplate,
      disable_shuffling: bool,
  ) -> _SplitInfoFuture:
    """Split generator for example generators.

    Args:
      split_name: str, name of the split being generated.
      generator: Iterable[KeyExample], iterable yielding (key, example) pairs.
      filename_template: Template to format the filename for a shard.
      disable_shuffling: if True, examples are written in generation order
        instead of being shuffled.

    Returns:
      future: The future containing the `tfds.core.SplitInfo`.
    """
    if self._max_examples_per_split is not None:
      logging.warning('Splits capped at %s examples max.',
                      self._max_examples_per_split)
      generator = itertools.islice(generator, self._max_examples_per_split)
      total_num_examples = self._max_examples_per_split
    else:
      # If the dataset info has been pre-downloaded from the internet,
      # we can use the pre-computed number of examples for the progress bar.
      split_info = self._split_dict.get(split_name)
      if split_info and split_info.num_examples:
        total_num_examples = split_info.num_examples
      else:
        total_num_examples = None

    writer = writer_lib.Writer(
        serializer=example_serializer.ExampleSerializer(
            self._features.get_serialized_info()),
        filename_template=filename_template,
        hash_salt=split_name,
        disable_shuffling=disable_shuffling,
        # TODO(weide) remove this because it's already in filename_template?
        file_format=self._file_format,
    )
    for key, example in utils.tqdm(
        generator,
        desc=f'Generating {split_name} examples...',
        unit=' examples',
        total=total_num_examples,
        leave=False,
    ):
      try:
        example = self._features.encode_example(example)
      except Exception as e:  # pylint: disable=broad-except
        utils.reraise(e, prefix=f'Failed to encode example:\n{example}\n')
      writer.write(key, example)
    shard_lengths, total_size = writer.finalize()

    split_info = splits_lib.SplitInfo(
        name=split_name,
        shard_lengths=shard_lengths,
        num_bytes=total_size,
        filename_template=filename_template,
    )
    return _SplitInfoFuture(lambda: split_info)
Example #10
 def _example_serializer(self):
   from tensorflow_datasets.core import example_serializer  # pytype: disable=import-error  # pylint: disable=g-import-not-at-top
   example_specs = self.get_serialized_info()
   return example_serializer.ExampleSerializer(example_specs)
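
For reference, a short sketch of using such a serializer directly (it reuses the hypothetical features dict from the sketch under Example #1; serialize_example is the same call used there):

# Equivalent to what the helper above constructs:
serializer = example_serializer.ExampleSerializer(features.get_serialized_info())
serialized = serializer.serialize_example(features.encode_example({'label': 'dog'}))
# `serialized` holds the bytes of a tf.train.Example proto ready to be written.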
Example #11
 def __init__(self, example_specs, path, hash_salt):
     self._serializer = example_serializer.ExampleSerializer(example_specs)
     self._shuffler = shuffle.Shuffler(os.path.dirname(path), hash_salt)
     self._num_examples = 0
     self._path = path
Example #12
 def __init__(self, example_specs):
   super(TFRecordExampleAdapter, self).__init__(example_specs)
   self._serializer = example_serializer.ExampleSerializer(
       example_specs)
   self._parser = example_parser.ExampleParser(example_specs)