def _shuffle_tfrecord(path, random_gen):
  """Shuffle a single record file in memory."""
  # Read all records.
  record_iter = tf.compat.v1.io.tf_record_iterator(path)
  all_records = [
      r for r in utils.tqdm(record_iter, desc="Reading...", unit=" examples")
  ]
  # Shuffle in memory.
  random_gen.shuffle(all_records)
  # Write all records back.
  with tf.io.TFRecordWriter(path) as writer:
    for record in utils.tqdm(all_records, desc="Writing...", unit=" examples"):
      writer.write(record)
def _compute_dynamic_properties(self, builder):
  """Update from the DatasetBuilder."""
  # Fill other things by going over the dataset.
  splits = self.splits
  for split_info in utils.tqdm(
      splits.values(), desc="Computing statistics...", unit=" split"):
    try:
      split_name = split_info.name
      # Fill DatasetFeatureStatistics.
      dataset_feature_statistics, schema = get_dataset_feature_statistics(
          builder, split_name)

      # Add the statistics to this split.
      split_info.statistics.CopyFrom(dataset_feature_statistics)

      # Set the schema at the top-level since this is independent of the
      # split.
      self.as_proto.schema.CopyFrom(schema)
    except tf.errors.InvalidArgumentError:
      # This means there is no such split, even though it was specified in
      # the info. The least we can do is log this.
      logging.error(("%s's info() property specifies split %s, but it "
                     "doesn't seem to have been generated. Please ensure "
                     "that the data was downloaded for this split and re-run "
                     "download_and_prepare."), self.name, split_name)
      raise

  # Set splits to trigger proto update in setter.
  self._set_splits(splits)
def finalize(self):
  """Effectively writes examples to the tfrecord files."""
  print("Shuffling and writing examples to %s" % self._path)
  shard_specs = _get_shard_specs(self._num_examples, self._shuffler.size,
                                 self._shuffler.bucket_lengths, self._path)
  # Here we just loop over the examples, and don't use the instructions, just
  # the final number of examples in every shard. Instructions could be used to
  # parallelize, but one would need to be careful not to sort buckets twice.
  examples_generator = iter(
      utils.tqdm(self._shuffler, total=self._num_examples, unit=" examples",
                 leave=False))
  try:
    for shard_spec in shard_specs:
      iterator = itertools.islice(examples_generator, 0,
                                  shard_spec.examples_number)
      _write_tfrecord(shard_spec.path, iterator)
  except shuffle.DuplicatedKeysError as err:
    _raise_error_for_duplicated_keys(err.item1, err.item2,
                                     self._example_specs)
  shard_lengths = [int(spec.examples_number) for spec in shard_specs]
  logging.info("Done writing %s. Shard lengths: %s", self._path,
               shard_lengths)
  return shard_lengths, self._shuffler.size
def _download_and_prepare(
    self,
    dl_manager: download.DownloadManager,
    download_config: download.DownloadConfig,
) -> None:
  """Generates all splits and updates the dataset info with the split infos."""
  split_builder = split_builder_lib.SplitBuilder(
      split_dict=self.info.splits,
      features=self.info.features,
      max_examples_per_split=download_config.max_examples_per_split,
      beam_options=download_config.beam_options,
      beam_runner=download_config.beam_runner,
  )
  # Wrap the generation inside a context manager.
  # If `beam` is used during generation (when a pipeline gets created),
  # the context manager is equivalent to `with beam.Pipeline()`.
  # Otherwise, this is a no-op.
  with split_builder.maybe_beam_pipeline():
    # If the signature has a `pipeline` kwarg, create the pipeline now and
    # forward it to `self._split_generators`.
    signature = inspect.signature(self._split_generators)
    if "pipeline" in signature.parameters.keys():
      optional_pipeline_kwargs = dict(pipeline=split_builder.beam_pipeline)
    else:
      optional_pipeline_kwargs = {}
    split_generators = self._split_generators(  # pylint: disable=unexpected-keyword-arg
        dl_manager, **optional_pipeline_kwargs
    )
    # TODO(tfds): Could be removed once all datasets are migrated.
    # https://github.com/tensorflow/datasets/issues/2537
    # Legacy mode (eventually convert list[SplitGeneratorLegacy] -> dict).
    split_generators = split_builder.normalize_legacy_split_generators(
        split_generators=split_generators,
        generator_fn=self._generate_examples,
        is_beam=isinstance(self, BeamBasedBuilder),
    )
    # Ensure `all` isn't used as a key.
    _check_split_names(split_generators.keys())
    # Start generating data for all splits.
    split_info_futures = [
        split_builder.submit_split_generation(  # pylint: disable=g-complex-comprehension
            split_name=split_name,
            generator=generator,
            path=self._data_path / f"{self.name}-{split_name}.tfrecord",
        )
        for split_name, generator in utils.tqdm(
            split_generators.items(), unit=" splits", leave=False)
    ]
  # Finalize the splits (after Apache Beam has completed, if it was used).
  split_infos = [future.result() for future in split_info_futures]

  # Update the info object with the splits.
  # TODO(tfds): Should improve the API.
  split_dict = splits_lib.SplitDict(dataset_name=self.name)
  for split_info in split_infos:
    split_dict.add(split_info)
  self.info.update_splits_if_different(split_dict)
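# A toy, standalone sketch (not TFDS code) of the `inspect.signature` check
# used in `_download_and_prepare` above: the optional `pipeline` kwarg is only
# forwarded when `_split_generators` actually declares it. The function and
# `fake_pipeline` below are hypothetical stand-ins.
import inspect


def _split_generators(dl_manager, pipeline=None):  # hypothetical builder method
  del dl_manager, pipeline
  return {}


fake_pipeline = object()  # stand-in for `split_builder.beam_pipeline`
if "pipeline" in inspect.signature(_split_generators).parameters:
  optional_pipeline_kwargs = dict(pipeline=fake_pipeline)
else:
  optional_pipeline_kwargs = {}
split_generators = _split_generators(dl_manager=None,
                                     **optional_pipeline_kwargs)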
def _build_from_generator(
    self,
    split_name: str,
    generator: Iterable[KeyExample],
    path: type_utils.PathLike,
) -> _SplitInfoFuture:
  """Split generator for example generators.

  Args:
    split_name: Name of the split to generate.
    generator: Iterable yielding `(key, example)` pairs.
    path: Path where the split records are written.

  Returns:
    future: The future containing the `tfds.core.SplitInfo`.
  """
  if self._max_examples_per_split is not None:
    logging.warning('Splits capped at %s examples max.',
                    self._max_examples_per_split)
    generator = itertools.islice(generator, self._max_examples_per_split)
    total_num_examples = self._max_examples_per_split
  else:
    # If the dataset info has been pre-downloaded from the internet,
    # we can use the pre-computed number of examples for the progress bar.
    split_info = self._split_dict.get(split_name)
    if split_info and split_info.num_examples:
      total_num_examples = split_info.num_examples
    else:
      total_num_examples = None

  writer = tfrecords_writer.Writer(
      example_specs=self._features.get_serialized_info(),
      path=path,
      hash_salt=split_name,
      file_format=self._file_format,
  )
  for key, example in utils.tqdm(
      generator,
      desc=f'Generating {split_name} examples...',
      unit=' examples',
      total=total_num_examples,
      leave=False,
  ):
    try:
      example = self._features.encode_example(example)
    except Exception as e:  # pylint: disable=broad-except
      utils.reraise(e, prefix=f'Failed to encode example:\n{example}\n')
    writer.write(key, example)
  shard_lengths, total_size = writer.finalize()

  split_info = splits_lib.SplitInfo(
      name=split_name,
      shard_lengths=shard_lengths,
      num_bytes=total_size,
  )
  return _SplitInfoFuture(lambda: split_info)
def finalize(self):
  """Effectively writes examples to the tfrecord files."""
  filename = os.path.basename(self._path)
  shard_specs = _get_shard_specs(self._num_examples, self._shuffler.size,
                                 self._shuffler.bucket_lengths, self._path)
  # Here we just loop over the examples, and don't use the instructions, just
  # the final number of examples in every shard. Instructions could be used to
  # parallelize, but one would need to be careful not to sort buckets twice.
  examples_generator = iter(
      utils.tqdm(
          self._shuffler,
          desc=f"Shuffling {filename}...",
          total=self._num_examples,
          unit=" examples",
          leave=False,
      ))
  try:
    for shard_spec in shard_specs:
      iterator = itertools.islice(examples_generator, 0,
                                  shard_spec.examples_number)
      record_keys = _write_examples(shard_spec.path, iterator,
                                    self._file_format)

      # No shard keys returned (e.g. TFRecord format), so the index cannot
      # be created.
      if not record_keys:
        continue

      # The number of `record_keys` received should match the number of
      # examples written in this shard.
      if len(record_keys) != int(shard_spec.examples_number):
        raise RuntimeError(
            f"Length of example `keys` ({len(record_keys)}) does not match "
            f"`shard_spec.examples_number` ({shard_spec.examples_number}).")

      _write_index_file(shard_spec.index_path, record_keys)

  except shuffle.DuplicatedKeysError as err:
    _raise_error_for_duplicated_keys(err.item1, err.item2,
                                     self._example_specs)

  # Exhaust the iterator to clean up the TQDM bar.
  try:
    val = next(examples_generator)
  except StopIteration:
    pass
  else:
    raise ValueError(
        f"Shuffling more elements than expected. Additional element: {val}")

  shard_lengths = [int(spec.examples_number) for spec in shard_specs]
  logging.info("Done writing %s. Number of examples: %s (shards: %s)",
               filename, sum(shard_lengths), shard_lengths)
  return shard_lengths, self._shuffler.size
def finalize(self):
  """Effectively writes examples to the tfrecord files."""
  print('Shuffling and writing examples to %s' % self._path)
  number_of_shards = _get_number_shards(self._shuffler.size,
                                        self._num_examples)
  writer = _TFRecordWriter(self._path, self._num_examples, number_of_shards)
  for serialized_example in utils.tqdm(self._shuffler,
                                       total=self._num_examples,
                                       unit=' examples',
                                       leave=False):
    writer.write(serialized_example)
  shard_lengths = writer.finalize()
  logging.info('Done writing %s. Shard lengths: %s', self._path,
               shard_lengths)
  return shard_lengths
def _prepare_split(self, split_generator, max_examples_per_split):
  """Generates and writes the examples for the given split."""
  generator = self._generate_examples(**split_generator.gen_kwargs)
  split_info = split_generator.split_info
  if max_examples_per_split is not None:
    logging.warning("Splits capped at %s examples max.",
                    max_examples_per_split)
    generator = itertools.islice(generator, max_examples_per_split)
  if not self.version.implements(utils.Experiment.S3):
    return self._prepare_split_legacy(generator, split_info)
  fname = "{}-{}.tfrecord".format(self.name, split_generator.name)
  fpath = os.path.join(self._data_dir, fname)
  writer = tfrecords_writer.Writer(self._example_specs, fpath)
  for key, record in utils.tqdm(generator, unit=" examples",
                                total=split_info.num_examples, leave=False):
    example = self.info.features.encode_example(record)
    writer.write(key, example)
  shard_lengths = writer.finalize()
  split_generator.split_info.shard_lengths.extend(shard_lengths)
def _write_tfrecords_from_generator(generator, output_files, shuffle=True):
  """Writes generated str records to output_files in round-robin order."""
  if do_files_exist(output_files):
    raise ValueError(
        "Pre-processed files already exists: {}.".format(output_files))
  with _incomplete_files(output_files) as tmp_files:
    # Write all shards.
    writers = [tf.io.TFRecordWriter(fname) for fname in tmp_files]
    with _close_on_exit(writers) as writers:
      logging.info("Writing TFRecords")
      _round_robin_write(writers, generator)
    # Shuffle each shard.
    if shuffle:
      # WARNING: Use np instead of Python random because Python random
      # produces different values between Python 2 and 3 and between
      # architectures.
      random_gen = np.random.RandomState(42)
      for path in utils.tqdm(
          tmp_files, desc="Shuffling...", unit=" shard", leave=False):
        _shuffle_tfrecord(path, random_gen=random_gen)
def finalize(self):
  """Effectively writes examples to the tfrecord files."""
  filename = os.path.basename(self._path)
  shard_specs = _get_shard_specs(self._num_examples, self._shuffler.size,
                                 self._shuffler.bucket_lengths, self._path)
  # Here we just loop over the examples, and don't use the instructions, just
  # the final number of examples in every shard. Instructions could be used to
  # parallelize, but one would need to be careful not to sort buckets twice.
  examples_generator = iter(
      utils.tqdm(
          self._shuffler,
          desc=f"Shuffling {filename}...",
          total=self._num_examples,
          unit=" examples",
          leave=False,
      ))
  try:
    for shard_spec in shard_specs:
      iterator = itertools.islice(examples_generator, 0,
                                  shard_spec.examples_number)
      _write_examples(shard_spec.path, iterator, self._file_format)
  except shuffle.DuplicatedKeysError as err:
    _raise_error_for_duplicated_keys(err.item1, err.item2,
                                     self._example_specs)

  # Exhaust the iterator to clean up the TQDM bar.
  try:
    val = next(examples_generator)
  except StopIteration:
    pass
  else:
    raise ValueError(
        f"Shuffling more elements than expected. Additional element: {val}")

  shard_lengths = [int(spec.examples_number) for spec in shard_specs]
  logging.info("Done writing %s. Number of examples: %s (shards: %s)",
               filename, sum(shard_lengths), shard_lengths)
  return shard_lengths, self._shuffler.size
def prep_imagenet_validation_data(
    data_dir='/project/clusterability_in_neural_networks/datasets/imagenet2012',
    val_tar='ILSVRC2012_img_val.tar'):
  """Writes the ImageNet 2012 validation split to TFRecord files."""
  # Prior to running this, execute:
  #   mkdir datasets/imagenet2012
  #   cd datasets/imagenet2012
  #   wget [the imagenet 2012 validation tar download link]
  val_path = os.path.join(data_dir, val_tar)
  if not os.path.exists(val_path):
    raise FileNotFoundError(
        f'{val_path} does not exist. Manually download ILSVRC2012_img_val.tar '
        f'into {data_dir} and try again.')

  imagenet = tfds.image.Imagenet2012()
  dl_manager = tfds.download.DownloadManager(download_dir=data_dir)
  arch = dl_manager.iter_archive(val_path)
  val_gen = tfds.core.SplitGenerator(
      name=tfds.Split.VALIDATION,
      gen_kwargs={
          'archive': arch,
          'validation_labels': imagenet._get_validation_labels(val_path)
      })
  validation_labels = imagenet._get_validation_labels(val_path)
  main_gen = imagenet._generate_examples_validation(archive=arch,
                                                    labels=validation_labels)

  fname = "{}-{}.tfrecord".format('imagenet2012', val_gen.name)
  fpath = os.path.join(data_dir, fname)
  writer = tfrecords_writer.Writer(imagenet._example_specs, fpath,
                                   hash_salt=val_gen.name)
  for key, record in tfds_utils.tqdm(main_gen, unit=" examples",
                                     total=val_gen.split_info.num_examples,
                                     leave=False):
    example = imagenet.info.features.encode_example(record)
    writer.write(key, example)
  _, _ = writer.finalize()
def _round_robin_write(writers, generator):
  """Write records from generator round-robin across writers."""
  for i, example in enumerate(
      utils.tqdm(generator, unit=" examples", leave=False)):
    writers[i % len(writers)].write(example)
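# A toy, standalone sketch (not TFDS code) of the round-robin distribution
# performed by `_round_robin_write` above: example i goes to writer
# i % len(writers). Plain lists stand in for the TFRecord writers.
writers = [[] for _ in range(3)]
for i, example in enumerate(range(10)):
  writers[i % len(writers)].append(example)
assert writers == [[0, 3, 6, 9], [1, 4, 7], [2, 5, 8]]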
def get_dataset_feature_statistics(builder, split):
  """Calculate statistics for the specified split."""
  statistics = statistics_pb2.DatasetFeatureStatistics()
  # Make this to the best of our abilities.
  schema = schema_pb2.Schema()

  dataset = builder.as_dataset(split=split)

  # Just computing the number of examples for now.
  statistics.num_examples = 0

  # Feature dictionaries.
  feature_to_num_examples = collections.defaultdict(int)
  feature_to_min = {}
  feature_to_max = {}

  np_dataset = dataset_utils.as_numpy(dataset)
  for example in utils.tqdm(np_dataset, unit=" examples", leave=False):
    statistics.num_examples += 1

    assert isinstance(example, dict)

    feature_names = sorted(example.keys())
    for feature_name in feature_names:

      # Update the number of examples this feature appears in.
      feature_to_num_examples[feature_name] += 1

      feature_np = example[feature_name]

      # For compatibility in graph and eager mode, we can get PODs here and
      # everything may not be neatly wrapped up in numpy's ndarray.
      feature_dtype = type(feature_np)

      if isinstance(feature_np, np.ndarray):
        # If we have an empty array, then don't proceed further with computing
        # statistics on it.
        if feature_np.size == 0:
          continue

        feature_dtype = feature_np.dtype.type

      feature_min, feature_max = None, None
      is_numeric = (np.issubdtype(feature_dtype, np.number) or
                    feature_dtype == np.bool_)
      if is_numeric:
        feature_min = np.min(feature_np)
        feature_max = np.max(feature_np)

      # TODO(afrozm): What if shapes don't match? Populate ValueCount? Add
      # logic for that.

      # Set or update the min, max.
      if is_numeric:
        if ((feature_name not in feature_to_min) or
            (feature_to_min[feature_name] > feature_min)):
          feature_to_min[feature_name] = feature_min

        if ((feature_name not in feature_to_max) or
            (feature_to_max[feature_name] < feature_max)):
          feature_to_max[feature_name] = feature_max

  # Start here, we've processed all examples.

  output_shapes_dict = dataset.output_shapes
  output_types_dict = dataset.output_types

  for feature_name in sorted(feature_to_num_examples.keys()):
    # Try to fill in the schema.
    feature = schema.feature.add()
    feature.name = feature_name

    # TODO(afrozm): Make this work with nested structures, currently the
    # Schema proto has no support for it.
    maybe_feature_shape = output_shapes_dict[feature_name]
    if not isinstance(maybe_feature_shape, tf.TensorShape):
      logging.error(
          "Statistics generation doesn't work for nested structures yet")
      continue

    for dim in maybe_feature_shape.as_list():
      # We denote `None`s as -1 in the shape proto.
      feature.shape.dim.add().size = dim if dim else -1
    feature_type = output_types_dict[feature_name]
    feature.type = _FEATURE_TYPE_MAP.get(feature_type, schema_pb2.BYTES)

    common_statistics = statistics_pb2.CommonStatistics()
    common_statistics.num_non_missing = feature_to_num_examples[feature_name]
    common_statistics.num_missing = (
        statistics.num_examples - common_statistics.num_non_missing)

    feature_name_statistics = statistics.features.add()
    feature_name_statistics.name = feature_name

    # TODO(afrozm): This can be skipped, since type information was added to
    # the Schema.
    feature_name_statistics.type = _SCHEMA_TYPE_MAP.get(
        feature.type, statistics_pb2.FeatureNameStatistics.BYTES)

    if feature.type == schema_pb2.INT or feature.type == schema_pb2.FLOAT:
      numeric_statistics = statistics_pb2.NumericStatistics()
      # Use `.get`, as a Sequence(int) containing only empty arrays won't
      # contain any value.
      numeric_statistics.min = feature_to_min.get(feature_name, 0)
      numeric_statistics.max = feature_to_max.get(feature_name, 0)
      numeric_statistics.common_stats.CopyFrom(common_statistics)
      feature_name_statistics.num_stats.CopyFrom(numeric_statistics)
    else:
      # Let's shove it into BytesStatistics for now.
      bytes_statistics = statistics_pb2.BytesStatistics()
      bytes_statistics.common_stats.CopyFrom(common_statistics)
      feature_name_statistics.bytes_stats.CopyFrom(bytes_statistics)

  return statistics, schema
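# A toy, pure-Python sketch (not TFDS code) of the accumulation logic in
# `get_dataset_feature_statistics` above: count per-feature presence, track a
# running min/max, then derive the number of missing examples per feature.
import collections

import numpy as np

toy_examples = [
    {"label": np.int64(3), "score": np.array([0.1, 0.9])},
    {"label": np.int64(7)},  # "score" is missing in this example.
]
feature_to_num_examples = collections.defaultdict(int)
feature_to_min, feature_to_max = {}, {}
for ex in toy_examples:
  for name, value in ex.items():
    feature_to_num_examples[name] += 1
    lo, hi = np.min(value), np.max(value)
    feature_to_min[name] = min(feature_to_min.get(name, lo), lo)
    feature_to_max[name] = max(feature_to_max.get(name, hi), hi)
num_missing = {name: len(toy_examples) - n
               for name, n in feature_to_num_examples.items()}
assert num_missing == {"label": 0, "score": 1}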
def _build_from_generator(
    self,
    split_name: str,
    generator: Iterable[KeyExample],
    filename_template: naming.ShardedFileTemplate,
    disable_shuffling: bool,
) -> _SplitInfoFuture:
  """Split generator for example generators.

  Args:
    split_name: Name of the split to generate.
    generator: Iterable yielding `(key, example)` pairs.
    filename_template: Template to format the filename for a shard.
    disable_shuffling: Whether to keep the generation order instead of
      shuffling the examples.

  Returns:
    future: The future containing the `tfds.core.SplitInfo`.
  """
  if self._max_examples_per_split is not None:
    logging.warning('Splits capped at %s examples max.',
                    self._max_examples_per_split)
    generator = itertools.islice(generator, self._max_examples_per_split)
    total_num_examples = self._max_examples_per_split
  else:
    # If the dataset info has been pre-downloaded from the internet,
    # we can use the pre-computed number of examples for the progress bar.
    split_info = self._split_dict.get(split_name)
    if split_info and split_info.num_examples:
      total_num_examples = split_info.num_examples
    else:
      total_num_examples = None

  writer = writer_lib.Writer(
      serializer=example_serializer.ExampleSerializer(
          self._features.get_serialized_info()),
      filename_template=filename_template,
      hash_salt=split_name,
      disable_shuffling=disable_shuffling,
      # TODO(weide): remove this, as it is already in filename_template?
      file_format=self._file_format,
  )
  for key, example in utils.tqdm(
      generator,
      desc=f'Generating {split_name} examples...',
      unit=' examples',
      total=total_num_examples,
      leave=False,
  ):
    try:
      example = self._features.encode_example(example)
    except Exception as e:  # pylint: disable=broad-except
      utils.reraise(e, prefix=f'Failed to encode example:\n{example}\n')
    writer.write(key, example)
  shard_lengths, total_size = writer.finalize()

  split_info = splits_lib.SplitInfo(
      name=split_name,
      shard_lengths=shard_lengths,
      num_bytes=total_size,
      filename_template=filename_template,
  )
  return _SplitInfoFuture(lambda: split_info)
def _download_and_prepare(
    self,
    dl_manager: download.DownloadManager,
    download_config: download.DownloadConfig,
) -> None:
  """Generates all splits and updates the dataset info with the split infos."""
  split_builder = split_builder_lib.SplitBuilder(
      split_dict=self.info.splits,
      features=self.info.features,
      max_examples_per_split=download_config.max_examples_per_split,
      beam_options=download_config.beam_options,
      beam_runner=download_config.beam_runner,
      file_format=self._file_format,
  )
  # Wrap the generation inside a context manager.
  # If `beam` is used during generation (when a pipeline gets created),
  # the context manager is equivalent to `with beam.Pipeline()`.
  # Otherwise, this is a no-op.
  # By auto-detecting Beam, the user only has to change `_generate_examples`
  # to go from a non-beam to a beam dataset:
  # https://www.tensorflow.org/datasets/beam_datasets#instructions
  with split_builder.maybe_beam_pipeline():
    # If the signature has a `pipeline` kwarg, create the pipeline now and
    # forward it to `self._split_generators`.
    # We add this magic because the pipeline kwarg is only used by c4 and
    # we do not want to make the API more verbose for a single advanced case.
    signature = inspect.signature(self._split_generators)
    if "pipeline" in signature.parameters.keys():
      optional_pipeline_kwargs = dict(pipeline=split_builder.beam_pipeline)
    else:
      optional_pipeline_kwargs = {}
    split_generators = self._split_generators(  # pylint: disable=unexpected-keyword-arg
        dl_manager, **optional_pipeline_kwargs
    )
    # TODO(tfds): Could be removed once all datasets are migrated.
    # https://github.com/tensorflow/datasets/issues/2537
    # Legacy mode (eventually convert list[SplitGeneratorLegacy] -> dict).
    split_generators = split_builder.normalize_legacy_split_generators(
        split_generators=split_generators,
        generator_fn=self._generate_examples,
        is_beam=isinstance(self, BeamBasedBuilder),
    )

    # Ensure `all` isn't used as a key.
    _check_split_names(split_generators.keys())

    # The writer fails if the number of examples yielded is `0`, so we return
    # here.
    if download_config.max_examples_per_split == 0:
      return

    # Start generating data for all splits.
    path_suffix = file_adapters.ADAPTER_FOR_FORMAT[
        self._file_format].FILE_SUFFIX

    split_info_futures = [
        split_builder.submit_split_generation(  # pylint: disable=g-complex-comprehension
            split_name=split_name,
            generator=generator,
            path=self.data_path / f"{self.name}-{split_name}.{path_suffix}",
        )
        for split_name, generator in utils.tqdm(
            split_generators.items(), unit=" splits", leave=False)
    ]

  # Finalize the splits (after Apache Beam has completed, if it was used).
  split_infos = [future.result() for future in split_info_futures]

  # Update the info object with the splits.
  split_dict = splits_lib.SplitDict(split_infos, dataset_name=self.name)
  self.info.set_splits(split_dict)