Example 1
  def _write(
      self,
      to_write,
      salt: str = '',
      dataset_name: str = 'foo',
      split: str = 'train',
      disable_shuffling=False,
      file_format=file_adapters.DEFAULT_FILE_FORMAT,
  ):
    filetype_suffix = file_adapters.ADAPTER_FOR_FORMAT[file_format].FILE_SUFFIX
    filename_template = naming.ShardedFileTemplate(
        dataset_name=dataset_name,
        split=split,
        filetype_suffix=filetype_suffix,
        data_dir=self.tmp_dir)
    beam = lazy_imports_lib.lazy_imports.apache_beam
    writer = writer_lib.BeamWriter(
        serializer=testing.DummySerializer('dummy specs'),
        filename_template=filename_template,
        hash_salt=salt,
        disable_shuffling=disable_shuffling,
        file_format=file_format,
    )
    # Here we need to disable type check as `beam.Create` is not capable of
    # inferring the type of the PCollection elements.
    options = beam.options.pipeline_options.PipelineOptions(
        pipeline_type_check=False)
    with beam.Pipeline(options=options) as pipeline:

      @beam.ptransform_fn
      def _build_pcollection(pipeline):
        pcollection = pipeline | 'Start' >> beam.Create(to_write)
        return writer.write_from_pcollection(pcollection)

      _ = pipeline | 'test' >> _build_pcollection()  # pylint: disable=no-value-for-parameter
    return writer.finalize()
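
For orientation, here is a hedged sketch of how a test might call the `_write` helper above; the key/record pairs and the `salt` value are placeholders, and the exact structure returned by `finalize()` is left opaque here.

# Inside the same test case (illustrative values only).
to_write = [(1, b'a'), (2, b'b'), (3, b'c')]
result = self._write(to_write=to_write, salt='split1')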
Example 2
 def _write(
     self,
     to_write,
     salt: str = '',
     dataset_name: str = 'foo',
     split: str = 'train',
     disable_shuffling=False,
     file_format=file_adapters.DEFAULT_FILE_FORMAT,
 ):
   filetype_suffix = file_adapters.ADAPTER_FOR_FORMAT[file_format].FILE_SUFFIX
   filename_template = naming.ShardedFileTemplate(
       dataset_name=dataset_name,
       split=split,
       filetype_suffix=filetype_suffix,
       data_dir=self.tmp_dir)
   writer = writer_lib.Writer(
       serializer=testing.DummySerializer('dummy specs'),
       filename_template=filename_template,
       hash_salt=salt,
       disable_shuffling=disable_shuffling,
       file_format=file_format)
   for key, record in to_write:
     writer.write(key, record)
   return writer.finalize()
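
A similar hedged sketch for the synchronous `Writer` variant; here the keys are supplied already sorted because shuffling is disabled, and the return value of `finalize()` is again treated as opaque.

# Inside the same test case (illustrative values only).
to_write = [(1, b'a'), (2, b'b'), (3, b'c')]  # keys already in ascending order
result = self._write(to_write=to_write, disable_shuffling=True)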
Example 3
 def _filename_template(self, split: str) -> naming.ShardedFileTemplate:
     return naming.ShardedFileTemplate(dataset_name='mnist',
                                       split=split,
                                       filetype_suffix='tfrecord',
                                       data_dir=self.tmp_dir)
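
Assuming the `sharded_filepath` API exercised in Example 4 below, the template returned by this helper can be resolved to a concrete shard path; the filename pattern in the comment follows the usual TFDS convention and is shown for illustration only.

# Inside the same test case.
template = self._filename_template(split='test')
path = template.sharded_filepath(shard_index=0, num_shards=10)
# `path` should resolve to something like:
#   <tmp_dir>/mnist-test.tfrecord-00000-of-00010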
Example 4
def test_sharded_file_template_no_template_incomplete():
  builder_dir = epath.Path('/my/path')
  template_without_split = naming.ShardedFileTemplate(
      data_dir=builder_dir, dataset_name='imagenet', filetype_suffix='riegeli')
  with pytest.raises(KeyError):
    template_without_split.sharded_filepath(shard_index=12, num_shards=100)
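
For contrast, a minimal self-contained sketch of the happy path that this test deliberately breaks: when the split is supplied, `sharded_filepath` can render a full shard path (the expected filename in the comment assumes the standard TFDS naming pattern).

from etils import epath
from tensorflow_datasets.core import naming

builder_dir = epath.Path('/my/path')
template_with_split = naming.ShardedFileTemplate(
    data_dir=builder_dir,
    dataset_name='imagenet',
    split='train',
    filetype_suffix='riegeli')
path = template_with_split.sharded_filepath(shard_index=12, num_shards=100)
# Expected to look like: /my/path/imagenet-train.riegeli-00012-of-00100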
Example 5
def test_sharded_file_template_empty_filetype_suffix():
  with pytest.raises(
      ValueError, match='Filetype suffix must be a non-empty string: .+'):
    naming.ShardedFileTemplate(
        data_dir='/path', dataset_name='mnist', filetype_suffix='')
Example 6
    def read_from_directory(self, dataset_info_dir: str) -> None:
        """Update DatasetInfo from the JSON files in `dataset_info_dir`.

    This function updates all the dynamically generated fields (num_examples,
    hash, time of creation,...) of the DatasetInfo.

    This will overwrite all previous metadata.

    Args:
      dataset_info_dir: `str` The directory containing the metadata file. This
        should be the root directory of a specific dataset version.

    Raises:
      FileNotFoundError: If the dataset_info.json can't be found.
    """
        logging.info("Load dataset info from %s", dataset_info_dir)

        json_filename = dataset_info_path(dataset_info_dir)
        if not tf.io.gfile.exists(json_filename):
            raise FileNotFoundError(
                "Tried to load `DatasetInfo` from a directory which does not exist or"
                " does not contain `dataset_info.json`. Please delete the directory "
                f"`{dataset_info_dir}`  if you are trying to re-generate the "
                "dataset.")

        # Load the metadata from disk
        parsed_proto = read_from_json(json_filename)

        if str(self.version) != parsed_proto.version:
            raise AssertionError(
                "The constructed DatasetInfo instance and the restored proto version "
                "do not match. Builder version: {}. Proto version: {}".format(
                    self.version, parsed_proto.version))

        self._identity = DatasetIdentity.from_proto(info_proto=parsed_proto,
                                                    data_dir=dataset_info_dir)

        # Update splits
        filename_template = naming.ShardedFileTemplate(
            dataset_name=self.name,
            data_dir=self.data_dir,
            filetype_suffix=parsed_proto.file_format or "tfrecord")
        split_dict = splits_lib.SplitDict.from_proto(
            repeated_split_infos=parsed_proto.splits,
            filename_template=filename_template)
        self.set_splits(split_dict)

        # Restore the feature metadata (vocabulary, labels names,...)
        if self.features:
            self.features.load_metadata(dataset_info_dir)
        # For `ReadOnlyBuilder`, reconstruct the features from the config.
        elif tf.io.gfile.exists(
                feature_lib.make_config_path(dataset_info_dir)):
            self._features = feature_lib.FeatureConnector.from_config(
                dataset_info_dir)
        # Restore the MetaDataDict from metadata.json if there is any
        if (self.metadata is not None
                or tf.io.gfile.exists(_metadata_filepath(dataset_info_dir))):
            # If the dataset was loaded from file, self.metadata will be `None`, so
            # we create a MetadataDict first.
            if self.metadata is None:
                self._metadata = MetadataDict()
            self.metadata.load_metadata(dataset_info_dir)

        # Update fields which are not defined in the code. This means that
        # the code will overwrite fields which are present in
        # dataset_info.json.
        # Iterate over every field declared in the DatasetInfo proto schema.
        fields_by_name = self.as_proto.DESCRIPTOR.fields_by_name
        for field_name, field in fields_by_name.items():
            field_value = getattr(self._info_proto, field_name)
            field_value_restored = getattr(parsed_proto, field_name)

            try:
                is_defined = self._info_proto.HasField(field_name)
            except ValueError:
                is_defined = bool(field_value)

            try:
                is_defined_in_restored = parsed_proto.HasField(field_name)
            except ValueError:
                is_defined_in_restored = bool(field_value_restored)

            # If field is defined in code, we ignore the value.
            if is_defined:
                if field_value != field_value_restored:
                    logging.info(
                        "Field info.%s from disk and from code do not match. "
                        "Keeping the one from code.", field_name)
                continue
            # If the field is also not defined in JSON file, we do nothing
            if not is_defined_in_restored:
                continue
            # Otherwise, we restore the dataset_info.json value
            if field.type == field.TYPE_MESSAGE:
                field_value.MergeFrom(field_value_restored)
            else:
                setattr(self._info_proto, field_name, field_value_restored)

        # Mark as fully initialized.
        self._fully_initialized = True
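
A hedged sketch of a call site for the method above: `info` stands for a `DatasetInfo` attached to a builder, and the directory path is purely illustrative; it only needs to contain the `dataset_info.json` written by a previous preparation run.

# Hypothetical usage; the directory path is a placeholder.
info.read_from_directory('/tmp/tensorflow_datasets/mnist/3.0.1')
# Assuming the restored dataset defines a 'train' split.
print(info.splits['train'].num_examples)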