def features_encode_decode(features_dict, example, decoders):
  """Runs the full pipeline: encode > serialize > parse > decode."""
  # Encode example
  encoded_example = features_dict.encode_example(example)

  # Serialize/deserialize the example
  specs = features_dict.get_serialized_info()
  serializer = example_serializer.ExampleSerializer(specs)
  parser = example_parser.ExampleParser(specs)

  serialized_example = serializer.serialize_example(encoded_example)
  ds = tf.data.Dataset.from_tensors(serialized_example)
  ds = ds.map(parser.parse_example)

  # Decode the example
  decode_fn = functools.partial(
      features_dict.decode_example,
      decoders=decoders,
  )
  ds = ds.map(decode_fn)

  if tf.executing_eagerly():
    out_tensor = next(iter(ds))
  else:
    out_tensor = tf.compat.v1.data.make_one_shot_iterator(ds).get_next()
  out_numpy = dataset_utils.as_numpy(out_tensor)
  return out_tensor, out_numpy
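# Hedged usage sketch for features_encode_decode above. The FeaturesDict and
# the example dict are illustrative assumptions, not taken from the original
# snippet; decoders=None keeps the default decoding for every feature.
import numpy as np
import tensorflow_datasets as tfds

features = tfds.features.FeaturesDict({
    'image': tfds.features.Image(shape=(28, 28, 1)),
    'label': tfds.features.ClassLabel(num_classes=10),
})
example = {
    'image': np.zeros((28, 28, 1), dtype=np.uint8),
    'label': 3,
}
_, out_numpy = features_encode_decode(features, example, decoders=None)
assert out_numpy['label'] == 3  # The round-trip preserves the encoded value.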
def __init__(self,
             example_specs,
             path,
             hash_salt,
             disable_shuffling,
             file_format=file_adapters.DEFAULT_FILE_FORMAT):
  """Init BeamWriter.

  Args:
    example_specs: spec to build the ExampleSerializer.
    path: str, path where to write the tfrecord file. Eg:
      "/foo/mnist-train.tfrecord". The suffix (eg: `.00000-of-00004`) will be
      added by the BeamWriter. Note that the file "{path}.split_info.json" is
      also created. It contains a list with the number of examples in each
      final shard. Eg: "[10,11,10,11]".
    hash_salt: string, the salt to use for hashing of keys.
    disable_shuffling: bool, specifies whether to shuffle the records.
    file_format: file_adapters.FileFormat, format of the record files in
      which the dataset will be read/written.
  """
  self._original_state = dict(
      example_specs=example_specs,
      path=path,
      hash_salt=hash_salt,
      disable_shuffling=disable_shuffling,
      file_format=file_format)
  self._path = os.fspath(path)
  self._split_info_path = "%s.split_info.json" % path
  self._serializer = example_serializer.ExampleSerializer(example_specs)
  self._example_specs = example_specs
  self._hasher = hashing.Hasher(hash_salt)
  self._split_info = None
  self._file_format = file_format
  self._disable_shuffling = disable_shuffling
def __init__(
    self,
    example_specs,
    filename_template: naming.ShardedFileTemplate,
    hash_salt,
    disable_shuffling: bool,
    file_format=file_adapters.DEFAULT_FILE_FORMAT,
):
  """Initializes Writer.

  Args:
    example_specs: spec to build ExampleSerializer.
    filename_template: template to format sharded filenames.
    hash_salt (str or bytes): salt to hash keys.
    disable_shuffling (bool): specifies whether to shuffle the records.
    file_format (FileFormat): format of the record files in which the
      dataset should be written.
  """
  self._example_specs = example_specs
  self._serializer = example_serializer.ExampleSerializer(example_specs)
  self._shuffler = shuffle.Shuffler(
      dirpath=filename_template.data_dir,
      hash_salt=hash_salt,
      disable_shuffling=disable_shuffling)
  self._num_examples = 0
  self._filename_template = filename_template
  self._file_format = file_format
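# Hedged usage sketch for the Writer above: write(key, example) and
# finalize() are the methods used by _build_from_generator further down in
# this listing; `specs`, `template`, and `encoded_examples` are placeholders.
writer = Writer(
    example_specs=specs,
    filename_template=template,
    hash_salt='train',
    disable_shuffling=False,
)
for key, encoded_example in enumerate(encoded_examples):
  writer.write(key, encoded_example)
# finalize() flushes the shuffled examples to sharded files and returns the
# per-shard example counts plus the total size in bytes.
shard_lengths, total_size = writer.finalize()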
def __init__(
    self,
    example_specs,
    path,
    hash_salt,
    disable_shuffling: bool,
    file_format=file_adapters.DEFAULT_FILE_FORMAT,
):
  """Initializes Writer.

  Args:
    example_specs: spec to build ExampleSerializer.
    path (str): path where records should be written.
    hash_salt (str or bytes): salt to hash keys.
    disable_shuffling (bool): specifies whether to shuffle the records.
    file_format (FileFormat): format of the record files in which the
      dataset should be written.
  """
  self._example_specs = example_specs
  self._serializer = example_serializer.ExampleSerializer(example_specs)
  self._shuffler = shuffle.Shuffler(
      os.path.dirname(path), hash_salt, disable_shuffling)
  self._num_examples = 0
  self._path = path
  self._file_format = file_format
def __init__(self,
             ds_info: dataset_info.DatasetInfo,
             max_examples_per_shard: int,
             overwrite: bool = True):
  """Creates a SequentialWriter.

  Args:
    ds_info: DatasetInfo for this dataset.
    max_examples_per_shard: maximum number of examples to write per shard.
    overwrite: if True, ignores and overwrites any existing data. Otherwise,
      loads the existing dataset and appends the new data (new data will
      always be created as new shards).
  """
  self._data_dir = ds_info.data_dir
  self._ds_name = ds_info.name
  self._ds_info = ds_info
  if not overwrite:
    try:
      self._ds_info.read_from_directory(self._data_dir)
      # read_from_directory does some checks, but not on the dataset name.
      if self._ds_info.name != self._ds_name:
        raise ValueError(
            f'Trying to append a dataset with name {self._ds_name}'
            f' to an existing dataset with name {self._ds_info.name}')
    except FileNotFoundError:
      self._ds_info.set_file_format(
          file_format=file_adapters.FileFormat.TFRECORD,
          # If the file format was already set, we want this to fail to warn
          # the user.
          override=False)
  else:
    self._ds_info.set_file_format(
        file_format=file_adapters.FileFormat.TFRECORD,
        # If the file format was already set, we want this to fail to warn
        # the user.
        override=False)
  self._filetype_suffix = ds_info.file_format.file_suffix
  self._max_examples_per_shard = max_examples_per_shard
  self._splits = {}
  if not overwrite:
    for split_name, split in ds_info.splits.items():
      self._splits[split_name] = _initialize_split(
          split_name=split_name,
          data_directory=self._data_dir,
          ds_name=self._ds_name,
          filetype_suffix=self._filetype_suffix,
          shard_lengths=split.shard_lengths,
          num_bytes=split.num_bytes)
  self._serializer = example_serializer.ExampleSerializer(
      self._ds_info.features.get_serialized_info())
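# Hedged usage sketch for SequentialWriter. Only __init__ is shown above; the
# method names below (initialize_splits, add_examples, close_splits) follow
# the public tfds.core.SequentialWriter API and are an assumption here.
# `ds_info` is a placeholder dataset_info.DatasetInfo.
writer = SequentialWriter(ds_info=ds_info, max_examples_per_shard=1000)
writer.initialize_splits(['train'])
writer.add_examples({'train': [{'text': 'hello'}, {'text': 'world'}]})
writer.close_splits(['train'])  # Writes the final metadata for each split.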
def _build_from_pcollection(
    self,
    split_name: str,
    generator: 'beam.PCollection[KeyExample]',
    filename_template: naming.ShardedFileTemplate,
    disable_shuffling: bool,
) -> _SplitInfoFuture:
  """Split generator for `beam.PCollection`."""
  # TODO(tfds): Should try to add support for `max_examples_per_split`.
  beam = lazy_imports_lib.lazy_imports.apache_beam

  beam_writer = writer_lib.BeamWriter(
      serializer=example_serializer.ExampleSerializer(
          self._features.get_serialized_info()),
      filename_template=filename_template,
      hash_salt=split_name,
      disable_shuffling=disable_shuffling,
      file_format=self._file_format,
  )

  def _encode_example(key_ex, encode_fn=self._features.encode_example):
    # We do not access self._features in this function to avoid pickling the
    # entire class.
    return key_ex[0], encode_fn(key_ex[1])

  # Note: We need to wrap the pipeline in a PTransform to avoid errors due
  # to duplicated `>> beam_labels`.
  @beam.ptransform_fn
  def _encode_pcollection(pipeline):
    """PTransform which builds a single split."""
    pcoll_examples = pipeline | 'Encode' >> beam.Map(_encode_example)
    return beam_writer.write_from_pcollection(pcoll_examples)

  # Add the PCollection to the pipeline.
  _ = generator | f'{split_name}_write' >> _encode_pcollection()  # pylint: disable=no-value-for-parameter

  def _resolve_future():
    if self._in_contextmanager:
      raise AssertionError('`future.result()` should be called after the '
                           '`maybe_beam_pipeline` contextmanager.')
    logging.info('Retrieving split info for %s...', split_name)
    shard_lengths, total_size = beam_writer.finalize()
    return splits_lib.SplitInfo(
        name=split_name,
        shard_lengths=shard_lengths,
        num_bytes=total_size,
        filename_template=filename_template,
    )

  return _SplitInfoFuture(_resolve_future)
def __init__(self, example_specs, path, hash_salt):
  """Init BeamWriter.

  Args:
    example_specs: spec to build the ExampleSerializer.
    path: str, path where to write the tfrecord file. Eg:
      "/foo/mnist-train.tfrecord". The suffix (eg: `.00000-of-00004`) will be
      added by the BeamWriter. Note that the file "{path}.shard_lengths.json"
      is also created. It contains a list with the number of examples in each
      final shard. Eg: "[10,11,10,11]".
    hash_salt: string, the salt to use for hashing of keys.
  """
  self._original_state = dict(
      example_specs=example_specs, path=path, hash_salt=hash_salt)
  self._path = path
  self._shards_length_path = "%s.shard_lengths.json" % path
  self._serializer = example_serializer.ExampleSerializer(example_specs)
  self._hasher = hashing.Hasher(hash_salt)
  self._shard_lengths = None
def __init__(
    self,
    example_specs,
    filename_template: naming.ShardedFileTemplate,
    hash_salt,
    disable_shuffling: bool,
    file_format: file_adapters.FileFormat = file_adapters.DEFAULT_FILE_FORMAT,
):
  """Init BeamWriter.

  Note that the file "{filepath_prefix}.split_info.json" is also created. It
  contains a list with the number of examples in each final shard. Eg:
  "[10,11,10,11]".

  Args:
    example_specs: spec to build the ExampleSerializer.
    filename_template: template to format sharded filenames.
    hash_salt: string, the salt to use for hashing of keys.
    disable_shuffling: bool, specifies whether to shuffle the records.
    file_format: file_adapters.FileFormat, format of the record files in
      which the dataset will be read/written.
  """
  self._original_state = dict(
      example_specs=example_specs,
      filename_template=filename_template,
      hash_salt=hash_salt,
      disable_shuffling=disable_shuffling,
      file_format=file_format)
  self._filename_template = filename_template
  self._split_info_path = (
      f"{filename_template.filepath_prefix()}.split_info.json")
  self._serializer = example_serializer.ExampleSerializer(example_specs)
  self._example_specs = example_specs
  self._hasher = hashing.Hasher(hash_salt)
  self._split_info = None
  self._file_format = file_format
  self._disable_shuffling = disable_shuffling
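# Hedged usage sketch for BeamWriter: write_from_pcollection() and finalize()
# are the calls made by _build_from_pcollection earlier in this listing; the
# pipeline and the encoded examples here are placeholders. finalize() must
# run only after the pipeline has completed.
import apache_beam as beam

writer = BeamWriter(
    example_specs=specs,
    filename_template=template,
    hash_salt='train',
    disable_shuffling=False,
)
with beam.Pipeline() as pipeline:
  pcoll = pipeline | beam.Create([(0, encoded_ex0), (1, encoded_ex1)])
  _ = writer.write_from_pcollection(pcoll)
shard_lengths, total_size = writer.finalize()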
def _build_from_generator(
    self,
    split_name: str,
    generator: Iterable[KeyExample],
    filename_template: naming.ShardedFileTemplate,
    disable_shuffling: bool,
) -> _SplitInfoFuture:
  """Split generator for example generators.

  Args:
    split_name: name of the split to build.
    generator: iterable of (key, example) pairs to write.
    filename_template: template to format the filename for a shard.
    disable_shuffling: specifies whether to shuffle the examples.

  Returns:
    future: The future containing the `tfds.core.SplitInfo`.
  """
  if self._max_examples_per_split is not None:
    logging.warning('Splits capped at %s examples max.',
                    self._max_examples_per_split)
    generator = itertools.islice(generator, self._max_examples_per_split)
    total_num_examples = self._max_examples_per_split
  else:
    # If dataset info has been pre-downloaded from the internet, we can use
    # the pre-computed number of examples for the progress bar.
    split_info = self._split_dict.get(split_name)
    if split_info and split_info.num_examples:
      total_num_examples = split_info.num_examples
    else:
      total_num_examples = None

  writer = writer_lib.Writer(
      serializer=example_serializer.ExampleSerializer(
          self._features.get_serialized_info()),
      filename_template=filename_template,
      hash_salt=split_name,
      disable_shuffling=disable_shuffling,
      # TODO(weide): remove this, because it's already in filename_template?
      file_format=self._file_format,
  )
  for key, example in utils.tqdm(
      generator,
      desc=f'Generating {split_name} examples...',
      unit=' examples',
      total=total_num_examples,
      leave=False,
  ):
    try:
      example = self._features.encode_example(example)
    except Exception as e:  # pylint: disable=broad-except
      utils.reraise(e, prefix=f'Failed to encode example:\n{example}\n')
    writer.write(key, example)
  shard_lengths, total_size = writer.finalize()

  split_info = splits_lib.SplitInfo(
      name=split_name,
      shard_lengths=shard_lengths,
      num_bytes=total_size,
      filename_template=filename_template,
  )
  return _SplitInfoFuture(lambda: split_info)
def _example_serializer(self):
  """Returns the `ExampleSerializer` built from these features' specs."""
  from tensorflow_datasets.core import example_serializer  # pytype: disable=import-error # pylint: disable=g-import-not-at-top
  example_specs = self.get_serialized_info()
  return example_serializer.ExampleSerializer(example_specs)
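# Hedged sketch of using the serializer built above: serialize_example() is
# the same call made in features_encode_decode at the top of this listing.
# `features` (a FeaturesDict) and `example` are placeholders, and
# _example_serializer is assumed callable as the plain method shown above.
serializer = features._example_serializer()
serialized = serializer.serialize_example(features.encode_example(example))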
def __init__(self, example_specs, path, hash_salt):
  self._serializer = example_serializer.ExampleSerializer(example_specs)
  self._shuffler = shuffle.Shuffler(os.path.dirname(path), hash_salt)
  self._num_examples = 0
  self._path = path
def __init__(self, example_specs):
  super(TFRecordExampleAdapter, self).__init__(example_specs)
  self._serializer = example_serializer.ExampleSerializer(example_specs)
  self._parser = example_parser.ExampleParser(example_specs)