예제 #1
0
def _shuffle_tfrecord(path, random_gen):
    """Rewrite the record file at `path` with its records in shuffled order.

    Args:
      path: Location of the record file. It is read completely, then
        overwritten.
      random_gen: Object exposing `shuffle(list)`; used to permute the
        records in place.
    """
    # Load every record into memory first.
    reader = utils.tqdm(
        tf.compat.v1.io.tf_record_iterator(path),
        desc="Reading...",
        unit=" examples",
    )
    records = list(reader)

    # Permute the in-memory list.
    random_gen.shuffle(records)

    # Overwrite the same file with the permuted records.
    with tf.io.TFRecordWriter(path) as writer:
        progress = utils.tqdm(records, desc="Writing...", unit=" examples")
        for serialized in progress:
            writer.write(serialized)
예제 #2
0
  def _compute_dynamic_properties(self, builder):
    """Fill per-split statistics and the schema by reading the builder's data.

    Args:
      builder: Dataset builder passed to `get_dataset_feature_statistics` for
        every split in `self.splits`.

    Raises:
      tf.errors.InvalidArgumentError: Logged and re-raised when computing the
        statistics of a split fails with this error.
    """
    all_splits = self.splits
    progress = utils.tqdm(
        all_splits.values(), desc="Computing statistics...", unit=" split")
    for info in progress:
      try:
        split_name = info.name
        stats, schema = get_dataset_feature_statistics(builder, split_name)

        # Attach the computed statistics to this split.
        info.statistics.CopyFrom(stats)

        # The schema is independent of the split, so it lives at the top
        # level of the proto.
        self.as_proto.schema.CopyFrom(schema)

      except tf.errors.InvalidArgumentError:
        # The split is declared in the info but does not seem to exist; the
        # least we can do is log it before propagating the error.
        logging.error(
            "%s's info() property specifies split %s, but it "
            "doesn't seem to have been generated. Please ensure "
            "that the data was downloaded for this split and re-run "
            "download_and_prepare.", self.name, split_name)
        raise

    # Assigning through the setter triggers the proto update.
    self._set_splits(all_splits)
 def finalize(self):
     """Drain the shuffler and write its examples to the tfrecord shards.

     Returns:
       A `(shard_lengths, size)` tuple: the number of examples of every shard
       (as ints) and `self._shuffler.size`.
     """
     print("Shuffling and writing examples to %s" % self._path)
     specs = _get_shard_specs(
         self._num_examples,
         self._shuffler.size,
         self._shuffler.bucket_lengths,
         self._path,
     )
     # One iterator over the shuffled examples is shared by all shards: each
     # shard simply consumes its own number of examples from it. The shard
     # instructions are not used; they could help to parallelize, but care
     # would be needed not to sort buckets twice.
     progress = utils.tqdm(
         self._shuffler,
         total=self._num_examples,
         unit=" examples",
         leave=False,
     )
     examples = iter(progress)
     try:
         for spec in specs:
             shard_examples = itertools.islice(examples, 0,
                                               spec.examples_number)
             _write_tfrecord(spec.path, shard_examples)
     except shuffle.DuplicatedKeysError as err:
         _raise_error_for_duplicated_keys(err.item1, err.item2,
                                          self._example_specs)
     lengths = [int(spec.examples_number) for spec in specs]
     logging.info("Done writing %s. Shard lengths: %s", self._path, lengths)
     return lengths, self._shuffler.size
예제 #4
0
  def _download_and_prepare(
      self,
      dl_manager: download.DownloadManager,
      download_config: download.DownloadConfig,
  ) -> None:
    """Generate every split and record the resulting split infos on the info.

    Args:
      dl_manager: Forwarded as first argument to `self._split_generators`.
      download_config: Provides `max_examples_per_split`, `beam_options` and
        `beam_runner` for the split builder.
    """
    builder = split_builder_lib.SplitBuilder(
        split_dict=self.info.splits,
        features=self.info.features,
        max_examples_per_split=download_config.max_examples_per_split,
        beam_options=download_config.beam_options,
        beam_runner=download_config.beam_runner,
    )
    # When `beam` is used during generation (a pipeline gets created), this
    # context manager is equivalent to `with beam.Pipeline()`; otherwise it is
    # a no-op.
    with builder.maybe_beam_pipeline():
      # Only create and forward the pipeline when `_split_generators` declares
      # a `pipeline` parameter.
      accepted_params = inspect.signature(self._split_generators).parameters
      extra_kwargs = {}
      if "pipeline" in accepted_params:
        extra_kwargs["pipeline"] = builder.beam_pipeline
      generators = self._split_generators(  # pylint: disable=unexpected-keyword-arg
          dl_manager, **extra_kwargs
      )
      # TODO(tfds): Could be removed once all datasets are migrated.
      # https://github.com/tensorflow/datasets/issues/2537
      # Legacy mode (eventually convert list[SplitGeneratorLegacy] -> dict)
      generators = builder.normalize_legacy_split_generators(
          split_generators=generators,
          generator_fn=self._generate_examples,
          is_beam=isinstance(self, BeamBasedBuilder),
      )

      # `all` must not be used as a split name.
      _check_split_names(generators.keys())

      # Submit the generation of each split.
      futures = []
      progress = utils.tqdm(generators.items(), unit=" splits", leave=False)
      for split_name, generator in progress:
        futures.append(
            builder.submit_split_generation(
                split_name=split_name,
                generator=generator,
                path=self._data_path / f"{self.name}-{split_name}.tfrecord",
            ))
    # Resolve the futures once the context manager has exited (i.e. after
    # apache beam completed, if it was used).
    split_infos = [future.result() for future in futures]

    # Push the new splits onto the info object.
    # TODO(tfds): Should improve the API.
    new_split_dict = splits_lib.SplitDict(dataset_name=self.name)
    for info in split_infos:
      new_split_dict.add(info)
    self.info.update_splits_if_different(new_split_dict)
예제 #5
0
    def _build_from_generator(
        self,
        split_name: str,
        generator: Iterable[KeyExample],
        path: type_utils.PathLike,
    ) -> _SplitInfoFuture:
        """Encode and write the examples of one split from a plain generator.

        Args:
          split_name: Name of the split; also used as the writer's hash salt.
          generator: Iterable of `(key, example)` pairs.
          path: Output path handed to the tfrecords writer.

        Returns:
          future: The future containing the `tfds.core.SplitInfo`.
        """
        max_examples = self._max_examples_per_split
        if max_examples is None:
            # If dataset info has been pre-downloaded from the internet, its
            # pre-computed number of examples sizes the progress bar.
            known_info = self._split_dict.get(split_name)
            if known_info and known_info.num_examples:
                total_num_examples = known_info.num_examples
            else:
                total_num_examples = None
        else:
            logging.warning('Splits capped at %s examples max.', max_examples)
            generator = itertools.islice(generator, max_examples)
            total_num_examples = max_examples

        writer = tfrecords_writer.Writer(
            example_specs=self._features.get_serialized_info(),
            path=path,
            hash_salt=split_name,
            file_format=self._file_format,
        )
        progress = utils.tqdm(
            generator,
            desc=f'Generating {split_name} examples...',
            unit=' examples',
            total=total_num_examples,
            leave=False,
        )
        for key, item in progress:
            try:
                item = self._features.encode_example(item)
            except Exception as e:  # pylint: disable=broad-except
                # `item` still holds the un-encoded example here.
                utils.reraise(e, prefix=f'Failed to encode example:\n{item}\n')
            writer.write(key, item)
        shard_lengths, total_size = writer.finalize()

        result_info = splits_lib.SplitInfo(
            name=split_name,
            shard_lengths=shard_lengths,
            num_bytes=total_size,
        )
        return _SplitInfoFuture(lambda: result_info)
예제 #6
0
    def finalize(self):
        """Drain the shuffler, write every shard and, if possible, its index.

        Returns:
          A `(shard_lengths, size)` tuple: the number of examples of every
          shard (as ints) and `self._shuffler.size`.

        Raises:
          RuntimeError: If the number of keys returned for a shard differs
            from the number of examples expected in that shard.
          ValueError: If the shuffler still yields an element after all
            shards were written.
        """
        filename = os.path.basename(self._path)
        shard_specs = _get_shard_specs(
            self._num_examples,
            self._shuffler.size,
            self._shuffler.bucket_lengths,
            self._path,
        )
        # One iterator over the shuffled examples is shared by all shards:
        # each shard consumes its own number of examples from it. The shard
        # instructions are not used; they could help to parallelize, but care
        # would be needed not to sort buckets twice.
        progress = utils.tqdm(
            self._shuffler,
            desc=f"Shuffling {filename}...",
            total=self._num_examples,
            unit=" examples",
            leave=False,
        )
        examples_generator = iter(progress)
        try:
            for shard_spec in shard_specs:
                shard_examples = itertools.islice(
                    examples_generator, 0, shard_spec.examples_number)
                record_keys = _write_examples(shard_spec.path, shard_examples,
                                              self._file_format)

                # Without keys (e.g: TFRecord format) no index can be built.
                if not record_keys:
                    continue

                # The number of keys must equal the number of examples
                # written in this shard.
                written = len(record_keys)
                if written != int(shard_spec.examples_number):
                    raise RuntimeError(
                        f"Length of example `keys` ({written}) does not match "
                        f"`shard_spec.examples_number: (`{shard_spec.examples_number})"
                    )
                _write_index_file(shard_spec.index_path, record_keys)

        except shuffle.DuplicatedKeysError as err:
            _raise_error_for_duplicated_keys(err.item1, err.item2,
                                             self._example_specs)

        # Pull once more so the iterator finishes and TQDM gets cleaned up;
        # anything still coming out means more elements than expected.
        exhausted = object()
        val = next(examples_generator, exhausted)
        if val is not exhausted:
            raise ValueError(
                f"Shuffling more elements than expected. Additional element: {val}"
            )

        lengths = [int(spec.examples_number) for spec in shard_specs]
        logging.info("Done writing %s. Number of examples: %s (shards: %s)",
                     filename, sum(lengths), lengths)
        return lengths, self._shuffler.size
예제 #7
0
 def finalize(self):
     """Write every example of the shuffler through a `_TFRecordWriter`.

     Returns:
       The value returned by the writer's `finalize()` (logged as the shard
       lengths).
     """
     print('Shuffling and writing examples to %s' % self._path)
     num_shards = _get_number_shards(self._shuffler.size, self._num_examples)
     writer = _TFRecordWriter(self._path, self._num_examples, num_shards)
     progress = utils.tqdm(
         self._shuffler,
         total=self._num_examples,
         unit=' examples',
         leave=False,
     )
     for example in progress:
         writer.write(example)
     lengths = writer.finalize()
     logging.info('Done writing %s. Shard lengths: %s', self._path, lengths)
     return lengths
예제 #8
0
 def _prepare_split(self, split_generator, max_examples_per_split):
   """Generate, encode and write the examples of a single split.

   Args:
     split_generator: Provides `gen_kwargs` (forwarded to
       `self._generate_examples`), `name` (used in the output file name) and
       `split_info` (its `num_examples` sizes the progress bar and its
       `shard_lengths` receives the result).
     max_examples_per_split: If not `None`, only this many examples are taken
       from the generator.

   Returns:
     The result of `self._prepare_split_legacy` when the version does not
     implement the S3 experiment; otherwise `None`.
   """
   generator = self._generate_examples(**split_generator.gen_kwargs)
   split_info = split_generator.split_info
   if max_examples_per_split is not None:
     # `logging.warn` is a deprecated alias; `warning` is the supported name.
     logging.warning("Splits capped at %s examples max.",
                     max_examples_per_split)
     generator = itertools.islice(generator, max_examples_per_split)
   if not self.version.implements(utils.Experiment.S3):
     return self._prepare_split_legacy(generator, split_info)
   fname = "{}-{}.tfrecord".format(self.name, split_generator.name)
   fpath = os.path.join(self._data_dir, fname)
   writer = tfrecords_writer.Writer(self._example_specs, fpath)
   for key, record in utils.tqdm(generator, unit=" examples",
                                 total=split_info.num_examples, leave=False):
     example = self.info.features.encode_example(record)
     writer.write(key, example)
   shard_lengths = writer.finalize()
   split_generator.split_info.shard_lengths.extend(shard_lengths)
예제 #9
0
def _write_tfrecords_from_generator(generator, output_files, shuffle=True):
  """Write the generated records to `output_files`, optionally shuffling.

  Args:
    generator: Source of records, handed to `_round_robin_write`.
    output_files: Destination file names; none of them may already exist.
    shuffle: When true, each written shard is shuffled afterwards with a
      fixed-seed random generator.

  Raises:
    ValueError: If `do_files_exist(output_files)` is true.
  """
  if do_files_exist(output_files):
    raise ValueError(
        "Pre-processed files already exists: {}.".format(output_files))

  with _incomplete_files(output_files) as tmp_files:
    # First pass: write all shards.
    shard_writers = [tf.io.TFRecordWriter(tmp_path) for tmp_path in tmp_files]
    with _close_on_exit(shard_writers) as shard_writers:
      logging.info("Writing TFRecords")
      _round_robin_write(shard_writers, generator)

    # Second pass: shuffle every shard.
    if shuffle:
      # WARNING: np is used instead of Python random because Python random
      # produces different values between Python 2 and 3 and between
      # architectures.
      rng = np.random.RandomState(42)
      shards = utils.tqdm(
          tmp_files, desc="Shuffling...", unit=" shard", leave=False)
      for shard_path in shards:
        _shuffle_tfrecord(shard_path, random_gen=rng)
    def finalize(self):
        """Drain the shuffler and write its examples to the shard files.

        Returns:
          A `(shard_lengths, size)` tuple: the number of examples of every
          shard (as ints) and `self._shuffler.size`.

        Raises:
          ValueError: If the shuffler still yields an element after all
            shards were written.
        """
        filename = os.path.basename(self._path)
        specs = _get_shard_specs(
            self._num_examples,
            self._shuffler.size,
            self._shuffler.bucket_lengths,
            self._path,
        )
        # One iterator over the shuffled examples is shared by all shards:
        # each shard consumes its own number of examples from it. The shard
        # instructions are not used; they could help to parallelize, but care
        # would be needed not to sort buckets twice.
        progress = utils.tqdm(
            self._shuffler,
            desc=f"Shuffling {filename}...",
            total=self._num_examples,
            unit=" examples",
            leave=False,
        )
        examples_generator = iter(progress)
        try:
            for spec in specs:
                shard_examples = itertools.islice(examples_generator, 0,
                                                  spec.examples_number)
                _write_examples(spec.path, shard_examples, self._file_format)
        except shuffle.DuplicatedKeysError as err:
            _raise_error_for_duplicated_keys(err.item1, err.item2,
                                             self._example_specs)

        # Pull once more so the iterator finishes and TQDM gets cleaned up;
        # anything still coming out means more elements than expected.
        exhausted = object()
        val = next(examples_generator, exhausted)
        if val is not exhausted:
            raise ValueError(
                f"Shuffling more elements than expected. Additional element: {val}"
            )

        lengths = [int(spec.examples_number) for spec in specs]
        logging.info("Done writing %s. Number of examples: %s (shards: %s)",
                     filename, sum(lengths), lengths)
        return lengths, self._shuffler.size
예제 #11
0
def prep_imagenet_validation_data(
        data_dir='/project/clusterability_in_neural_networks/datasets/imagenet2012',
        val_tar='ILSVRC2012_img_val.tar'):
    """Write the ImageNet 2012 validation examples to a tfrecord in `data_dir`.

    Prior to running this, execute:
        mkdir datasets/imagenet2012
        cd datasets/imagenet2012
        wget [the imagenet 2012 validation tar download link]

    Args:
        data_dir: Directory holding the validation tar; also used as the
            download directory and as the output directory.
        val_tar: File name of the validation archive inside `data_dir`.

    Raises:
        FileNotFoundError: If `data_dir/val_tar` does not exist.
    """
    val_path = os.path.join(data_dir, val_tar)

    if not os.path.exists(val_path):
        # Built from adjacent literals so no source indentation leaks into
        # the message.
        raise FileNotFoundError(
            f'{val_path} does not exist. Manually download {val_tar} '
            f'into {data_dir} and try again.')

    imagenet = tfds.image.Imagenet2012()
    dl_manager = tfds.download.DownloadManager(download_dir=data_dir)
    arch = dl_manager.iter_archive(val_path)
    # Compute the labels once and reuse them for both consumers below.
    validation_labels = imagenet._get_validation_labels(val_path)
    val_gen = tfds.core.SplitGenerator(
        name=tfds.Split.VALIDATION,
        gen_kwargs={
            'archive': arch,
            'validation_labels': validation_labels,
        })
    main_gen = imagenet._generate_examples_validation(archive=arch,
                                                      labels=validation_labels)
    fname = "{}-{}.tfrecord".format('imagenet2012', val_gen.name)
    fpath = os.path.join(data_dir, fname)
    writer = tfrecords_writer.Writer(imagenet._example_specs,
                                     fpath,
                                     hash_salt=val_gen.name)
    for key, record in tfds_utils.tqdm(main_gen,
                                       unit=" examples",
                                       total=val_gen.split_info.num_examples,
                                       leave=False):
        example = imagenet.info.features.encode_example(record)
        writer.write(key, example)
    # The returned values are not needed here.
    writer.finalize()
예제 #12
0
def _round_robin_write(writers, generator):
    """Distribute the generator's records over `writers` in rotation.

    Args:
      writers: Indexable collection of objects exposing `write(record)`.
      generator: Iterable of records; record number `i` goes to writer
        `i % len(writers)`.
    """
    num_writers = len(writers)
    progress = utils.tqdm(generator, unit=" examples", leave=False)
    for index, record in enumerate(progress):
        target = writers[index % num_writers]
        target.write(record)
예제 #13
0
def get_dataset_feature_statistics(builder, split):
  """Calculate statistics and a schema for the specified split.

  Every example of `builder.as_dataset(split=split)` is visited once. The
  number of examples is counted globally and per feature, and the running
  min/max is tracked for numeric and boolean features.

  Args:
    builder: Object exposing `as_dataset(split=...)`.
    split: Split identifier forwarded to `builder.as_dataset`.

  Returns:
    A `(statistics, schema)` tuple holding a
    `statistics_pb2.DatasetFeatureStatistics` and a `schema_pb2.Schema`.
  """
  statistics = statistics_pb2.DatasetFeatureStatistics()

  # Make this to the best of our abilities.
  schema = schema_pb2.Schema()

  dataset = builder.as_dataset(split=split)

  # Just computing the number of examples for now.
  statistics.num_examples = 0

  # Feature dictionaries.
  feature_to_num_examples = collections.defaultdict(int)
  feature_to_min = {}
  feature_to_max = {}

  np_dataset = dataset_utils.as_numpy(dataset)
  for example in utils.tqdm(np_dataset, unit=" examples", leave=False):
    statistics.num_examples += 1

    assert isinstance(example, dict)

    for feature_name in sorted(example):

      # Update the number of examples this feature appears in.
      feature_to_num_examples[feature_name] += 1

      feature_np = example[feature_name]

      # For compatibility in graph and eager mode, we can get PODs here and
      # everything may not be neatly wrapped up in numpy's ndarray.
      feature_dtype = type(feature_np)

      if isinstance(feature_np, np.ndarray):
        # If we have an empty array, then don't proceed further with computing
        # statistics on it.
        if feature_np.size == 0:
          continue

        feature_dtype = feature_np.dtype.type

      is_numeric = (np.issubdtype(feature_dtype, np.number) or
                    feature_dtype == np.bool_)
      if not is_numeric:
        continue

      # TODO(afrozm): What if shapes don't match? Populate ValueCount? Add
      # logic for that.

      # Set or update the running min and max of this feature.
      feature_min = np.min(feature_np)
      feature_max = np.max(feature_np)
      if ((feature_name not in feature_to_min) or
          (feature_to_min[feature_name] > feature_min)):
        feature_to_min[feature_name] = feature_min

      if ((feature_name not in feature_to_max) or
          (feature_to_max[feature_name] < feature_max)):
        feature_to_max[feature_name] = feature_max

  # Start here, we've processed all examples.

  output_shapes_dict = dataset.output_shapes
  output_types_dict = dataset.output_types

  for feature_name in sorted(feature_to_num_examples.keys()):
    # Try to fill in the schema.
    feature = schema.feature.add()
    feature.name = feature_name

    # TODO(afrozm): Make this work with nested structures, currently the Schema
    # proto has no support for it.
    maybe_feature_shape = output_shapes_dict[feature_name]
    if not isinstance(maybe_feature_shape, tf.TensorShape):
      logging.error(
          "Statistics generation doesn't work for nested structures yet")
      continue

    for dim in maybe_feature_shape.as_list():
      # We denote `None`s as -1 in the shape proto. Compare against `None`
      # explicitly so that a static dimension of size 0 is kept as 0.
      feature.shape.dim.add().size = dim if dim is not None else -1
    feature_type = output_types_dict[feature_name]
    feature.type = _FEATURE_TYPE_MAP.get(feature_type, schema_pb2.BYTES)

    common_statistics = statistics_pb2.CommonStatistics()
    common_statistics.num_non_missing = feature_to_num_examples[feature_name]
    common_statistics.num_missing = (
        statistics.num_examples - common_statistics.num_non_missing)

    feature_name_statistics = statistics.features.add()
    feature_name_statistics.name = feature_name

    # TODO(afrozm): This can be skipped, since type information was added to
    # the Schema.
    feature_name_statistics.type = _SCHEMA_TYPE_MAP.get(
        feature.type, statistics_pb2.FeatureNameStatistics.BYTES)

    if feature.type == schema_pb2.INT or feature.type == schema_pb2.FLOAT:
      numeric_statistics = statistics_pb2.NumericStatistics()
      # Uses `.get` as Sequence(int) containing only empty array won't contain
      # any value.
      numeric_statistics.min = feature_to_min.get(feature_name, 0)
      numeric_statistics.max = feature_to_max.get(feature_name, 0)
      numeric_statistics.common_stats.CopyFrom(common_statistics)
      feature_name_statistics.num_stats.CopyFrom(numeric_statistics)
    else:
      # Let's shove it into BytesStatistics for now.
      bytes_statistics = statistics_pb2.BytesStatistics()
      bytes_statistics.common_stats.CopyFrom(common_statistics)
      feature_name_statistics.bytes_stats.CopyFrom(bytes_statistics)

  return statistics, schema
예제 #14
0
  def _build_from_generator(
      self,
      split_name: str,
      generator: Iterable[KeyExample],
      filename_template: naming.ShardedFileTemplate,
      disable_shuffling: bool,
  ) -> _SplitInfoFuture:
    """Encode and write the examples of one split from a plain generator.

    Args:
      split_name: Name of the split; also used as the writer's hash salt.
      generator: Iterable of `(key, example)` pairs.
      filename_template: Template to format the filename for a shard.
      disable_shuffling: Forwarded to the writer's `disable_shuffling`
        argument.

    Returns:
      future: The future containing the `tfds.core.SplitInfo`.
    """
    max_examples = self._max_examples_per_split
    if max_examples is None:
      # If dataset info has been pre-downloaded from the internet, its
      # pre-computed number of examples sizes the progress bar.
      known_info = self._split_dict.get(split_name)
      if known_info and known_info.num_examples:
        total_num_examples = known_info.num_examples
      else:
        total_num_examples = None
    else:
      logging.warning('Splits capped at %s examples max.', max_examples)
      generator = itertools.islice(generator, max_examples)
      total_num_examples = max_examples

    serializer = example_serializer.ExampleSerializer(
        self._features.get_serialized_info())
    writer = writer_lib.Writer(
        serializer=serializer,
        filename_template=filename_template,
        hash_salt=split_name,
        disable_shuffling=disable_shuffling,
        # TODO(weide) remove this because it's already in filename_template?
        file_format=self._file_format,
    )
    progress = utils.tqdm(
        generator,
        desc=f'Generating {split_name} examples...',
        unit=' examples',
        total=total_num_examples,
        leave=False,
    )
    for key, item in progress:
      try:
        item = self._features.encode_example(item)
      except Exception as e:  # pylint: disable=broad-except
        # `item` still holds the un-encoded example here.
        utils.reraise(e, prefix=f'Failed to encode example:\n{item}\n')
      writer.write(key, item)
    shard_lengths, total_size = writer.finalize()

    result_info = splits_lib.SplitInfo(
        name=split_name,
        shard_lengths=shard_lengths,
        num_bytes=total_size,
        filename_template=filename_template,
    )
    return _SplitInfoFuture(lambda: result_info)
예제 #15
0
  def _download_and_prepare(
      self,
      dl_manager: download.DownloadManager,
      download_config: download.DownloadConfig,
  ) -> None:
    """Generate every split and set the resulting splits on the info object.

    Returns early, without setting any splits, when
    `download_config.max_examples_per_split` is 0.

    Args:
      dl_manager: Forwarded as first argument to `self._split_generators`.
      download_config: Provides `max_examples_per_split`, `beam_options` and
        `beam_runner` for the split builder.
    """
    builder = split_builder_lib.SplitBuilder(
        split_dict=self.info.splits,
        features=self.info.features,
        max_examples_per_split=download_config.max_examples_per_split,
        beam_options=download_config.beam_options,
        beam_runner=download_config.beam_runner,
        file_format=self._file_format,
    )
    # When `beam` is used during generation (a pipeline gets created), this
    # context manager is equivalent to `with beam.Pipeline()`; otherwise it is
    # a no-op. Thanks to this auto-detection, the user only has to change
    # `_generate_examples` to go from a non-beam to a beam dataset:
    # https://www.tensorflow.org/datasets/beam_datasets#instructions
    with builder.maybe_beam_pipeline():
      # Only create and forward the pipeline when `_split_generators` declares
      # a `pipeline` parameter. This magic exists because the kwarg is only
      # used by c4 and the API should not get more verbose for a single
      # advanced case.
      wants_pipeline = (
          "pipeline" in inspect.signature(self._split_generators).parameters)
      pipeline_kwargs = (
          {"pipeline": builder.beam_pipeline} if wants_pipeline else {})
      generators = self._split_generators(  # pylint: disable=unexpected-keyword-arg
          dl_manager, **pipeline_kwargs
      )
      # TODO(tfds): Could be removed once all datasets are migrated.
      # https://github.com/tensorflow/datasets/issues/2537
      # Legacy mode (eventually convert list[SplitGeneratorLegacy] -> dict)
      generators = builder.normalize_legacy_split_generators(
          split_generators=generators,
          generator_fn=self._generate_examples,
          is_beam=isinstance(self, BeamBasedBuilder),
      )

      # `all` must not be used as a split name.
      _check_split_names(generators.keys())

      # The writer fails when zero examples are yielded, so stop here.
      if download_config.max_examples_per_split == 0:
        return

      # Submit the generation of each split.
      path_suffix = file_adapters.ADAPTER_FOR_FORMAT[
          self._file_format].FILE_SUFFIX

      futures = []
      progress = utils.tqdm(generators.items(), unit=" splits", leave=False)
      for split_name, generator in progress:
        futures.append(
            builder.submit_split_generation(
                split_name=split_name,
                generator=generator,
                path=self.data_path / f"{self.name}-{split_name}.{path_suffix}",
            ))
    # Resolve the futures once the context manager has exited (i.e. after
    # apache beam completed, if it was used).
    split_infos = [future.result() for future in futures]

    # Store the new splits on the info object.
    new_split_dict = splits_lib.SplitDict(split_infos, dataset_name=self.name)
    self.info.set_splits(new_split_dict)