class DummyTFRecordBuilder(dataset_builder.GeneratorBasedBuilder):

  VERSION = utils.Version("0.0.0")

  def _split_generators(self, dl_manager):
    return [
        splits.SplitGenerator(
            name=splits.Split.TRAIN,
            num_shards=2,
            gen_kwargs={"range_": range(20)}),
        splits.SplitGenerator(
            name=splits.Split.VALIDATION,
            num_shards=1,
            gen_kwargs={"range_": range(20, 30)}),
        splits.SplitGenerator(
            name=splits.Split.TEST,
            num_shards=1,
            gen_kwargs={"range_": range(30, 40)}),
    ]

  def _generate_examples(self, range_):
    for i in range_:
      yield {
          "x": i,
          "y": np.array([-i]).astype(np.int64)[0],
          "z": tf.compat.as_text(str(i))
      }

  def _info(self):
    return dataset_info.DatasetInfo(
        builder=self,
        features=features.FeaturesDict({
            "x": tf.int64,
            "y": tf.int64,
            "z": tf.string,
        }),
    )
Example #2
class NestedSequenceBuilder(dataset_builder.GeneratorBasedBuilder):
    """Dataset containing nested sequences."""

    VERSION = utils.Version("0.0.1")

    def _info(self):
        return dataset_info.DatasetInfo(
            builder=self,
            features=features.FeaturesDict({
                "frames":
                features.Sequence({
                    "coordinates":
                    features.Sequence(
                        features.Tensor(shape=(2, ), dtype=tf.int32)),
                }),
            }),
        )

    def _split_generators(self, dl_manager):
        del dl_manager
        return [
            splits_lib.SplitGenerator(
                name=splits_lib.Split.TRAIN,
                gen_kwargs={},
            ),
        ]

    def _generate_examples(self):
        ex0 = [[[0, 1], [2, 3], [4, 5]], [], [[6, 7]]]
        ex1 = []
        ex2 = [
            [[10, 11]],
            [[12, 13], [14, 15]],
        ]
        for i, ex in enumerate([ex0, ex1, ex2]):
            yield i, {"frames": {"coordinates": ex}}
Example #3
class DummyMnist(dataset_builder.GeneratorBasedBuilder):
    """Test DatasetBuilder."""

    VERSION = utils.Version("1.0.0")

    def __init__(self, *args, **kwargs):
        self._num_shards = kwargs.pop("num_shards", 10)
        super(DummyMnist, self).__init__(*args, **kwargs)

    def _info(self):
        return dataset_info.DatasetInfo(
            builder=self,
            features=features.FeaturesDict({
                "image":
                features.Image(shape=(28, 28, 1)),
                "label":
                features.ClassLabel(num_classes=10),
            }),
        )

    def _split_generators(self, dl_manager):
        return [
            splits.SplitGenerator(name=splits.Split.TRAIN,
                                  num_shards=self._num_shards,
                                  gen_kwargs=dict()),
            splits.SplitGenerator(name=splits.Split.TEST,
                                  num_shards=1,
                                  gen_kwargs=dict()),
        ]

    def _generate_examples(self):
        for i in range(20):
            yield {
                "image": np.ones((28, 28, 1), dtype=np.uint8),
                "label": i % 10,
            }
Example #4
class DummyDatasetSharedGenerator(dataset_builder.GeneratorBasedBuilder):
  """Test DatasetBuilder."""

  VERSION = utils.Version("1.0.0")

  def _info(self):
    return dataset_info.DatasetInfo(
        builder=self,
        features=features.FeaturesDict({"x": tf.int64}),
        supervised_keys=("x", "x"),
    )

  def _split_generators(self, dl_manager):
    # Split the 30 examples from the generator into 2 train shards and 1 test
    # shard.
    del dl_manager
    return [splits.SplitGenerator(
        name=[splits.Split.TRAIN, splits.Split.TEST],
        num_shards=[2, 1],
    )]

  def _generate_examples(self):
    for i in range(30):
      yield {"x": i}
Example #5
def version(self):
    return utils.Version(self.as_proto.version)
Example #6
def _is_version_valid(version):
    try:
        return utils.Version(version) and True
    except ValueError:  # Invalid version (ex: incomplete data dir)
        return False
Example #7
def versions(self):
    """Versions (canonical + available), in preference order."""
    return [
        utils.Version(v) if isinstance(v, six.string_types) else v
        for v in [self.canonical_version] + self.supported_versions
    ]
Example #8
def _pick_version(self, version: str) -> utils.Version:
    return utils.Version(version)
Example #9
class FaultyS3DummyBeamDataset(DummyBeamDataset):

    VERSION = utils.Version("1.0.0")
Example #10
def canonical_version(self) -> utils.Version:
    if self._builder_config and self._builder_config.version:
        return utils.Version(self._builder_config.version)
    else:
        return utils.Version(self.VERSION)
Example #11
class MultiDirDataset(DummyDatasetSharedGenerator):  # pylint: disable=unused-variable
    VERSION = utils.Version("1.2.0")
Example #12
def write_metadata(
    *,
    data_dir: epath.PathLike,
    features: features_lib.feature.FeatureConnectorArg,
    split_infos: Union[None, epath.PathLike, List[split_lib.SplitInfo]] = None,
    version: Union[None, str, utils.Version] = None,
    check_data: bool = True,
    **ds_info_kwargs,
) -> None:
    """Add metadata required to load with TFDS.

    See documentation for usage:
    https://www.tensorflow.org/datasets/external_tfrecord

    Args:
      data_dir: Dataset path on which to save the metadata.
      features: dict of `tfds.features.FeatureConnector` matching the proto
        specs.
      split_infos: Can be either:
        * A path to the pre-computed split info values (the `out_dir` kwarg of
          `tfds.folder_dataset.compute_split_info`)
        * A list of `tfds.core.SplitInfo` (returned value of
          `tfds.folder_dataset.compute_split_info`)
        * `None` to auto-compute the split info.
      version: Optional dataset version (auto-inferred by default, falling back
        to 1.0.0).
      check_data: If True, perform an additional check to validate that the
        data in data_dir is valid.
      **ds_info_kwargs: Additional metadata forwarded to `tfds.core.DatasetInfo`
        (description, homepage, ...). Will appear in the documentation.
    """
    features = features_lib.features_dict.to_feature(features)
    data_dir = epath.Path(data_dir)
    # Extract the tf-record filenames
    tfrecord_files = [
        f for f in data_dir.iterdir() if naming.FilenameInfo.is_valid(f.name)
    ]
    if not tfrecord_files:
        raise ValueError(
            f'Could not find tf-record (or compatible format) in {data_dir}. '
            'Make sure to follow the pattern: '
            '`<dataset_name>-<split_name>.<file-extension>-xxxxxx-of-yyyyyy`')

    file_infos = [naming.FilenameInfo.from_str(f.name) for f in tfrecord_files]

    # Use set with tuple expansion syntax to ensure all names are consistent
    snake_name, = {f.dataset_name for f in file_infos}
    camel_name = naming.snake_to_camelcase(snake_name)
    filetype_suffix, = {f.filetype_suffix for f in file_infos}
    file_format = file_adapters.file_format_from_suffix(filetype_suffix)

    cls = types.new_class(
        camel_name,
        bases=(_WriteBuilder, ),
        kwds=dict(skip_registration=True),
        exec_body=None,
    )

    if version is None:  # Automatically detect the version
        if utils.Version.is_valid(data_dir.name):
            version = data_dir.name
        else:
            version = '1.0.0'
    cls.VERSION = utils.Version(version)

    # Create a dummy builder (use a non-existent folder to make sure
    # dataset_info.json is not restored)
    builder = cls(file_format=file_format, data_dir='/tmp/non-existent-dir/')

    # Create the metadata
    ds_info = dataset_info.DatasetInfo(
        builder=builder,
        features=features,
        **ds_info_kwargs,
    )
    ds_info.set_file_format(file_format)

    # Add the split infos
    split_dict = _load_splits(
        data_dir=data_dir,
        split_infos=split_infos,
        file_infos=file_infos,
        filetype_suffix=filetype_suffix,
        builder=builder,
    )
    ds_info.set_splits(split_dict)

    # Save all metadata (dataset_info.json, features.json,...)
    ds_info.write_to_directory(data_dir)

    # Make sure that the data can be loaded (feature connector match the actual
    # specs)
    if check_data:
        utils.print_notebook(
            'Metadata written. Testing by reading first example. '
            'Set check_data=False to skip.')
        builder = read_only_builder.builder_from_directory(data_dir)
        split_name = next(iter(builder.info.splits))
        _, = builder.as_dataset(
            split=f'{split_name}[:1]')  # Load the first example
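For reference, the function above is exposed publicly as `tfds.folder_dataset.write_metadata`. Below is a minimal usage sketch; the directory path and feature spec are hypothetical, and the directory is assumed to already contain shards named `<dataset_name>-<split_name>.<file-extension>-xxxxx-of-yyyyy`.

import tensorflow_datasets as tfds

data_dir = '/path/to/my_dataset/1.0.0'  # Hypothetical shard directory.

features = tfds.features.FeaturesDict({
    'image': tfds.features.Image(shape=(None, None, 3)),
    'label': tfds.features.ClassLabel(names=['no', 'yes']),
})

tfds.folder_dataset.write_metadata(
    data_dir=data_dir,
    features=features,
    split_infos=None,  # None => auto-compute the split info from the shards.
    description='Externally generated dataset.',  # Forwarded to DatasetInfo.
)

# Once the metadata is written, the folder loads like any prepared dataset.
builder = tfds.builder_from_directory(data_dir)
ds = builder.as_dataset(split='train')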
Example #13
class DummyNoConfMnist(testing.DummyDataset):
  """Same as DummyMnist (but declared here to avoid skip_registering issues)."""
  VERSION = utils.Version('0.1.0')
Example #14
    def test_restore_after_modification(self):
        # Create a DatasetInfo
        info = dataset_info.DatasetInfo(
            builder=self._builder,
            description="A description",
            supervised_keys=("input", "output"),
            homepage="http://some-location",
            citation="some citation",
            license="some license",
        )
        info.download_size = 456
        filepath_template = "{DATASET}-{SPLIT}.{FILEFORMAT}-{SHARD_X_OF_Y}"
        info.as_proto.splits.add(name="train",
                                 num_bytes=512,
                                 filepath_template=filepath_template)
        info.as_proto.splits.add(name="validation",
                                 num_bytes=64,
                                 filepath_template=filepath_template)
        info.as_proto.schema.feature.add()
        info.as_proto.schema.feature.add()  # Add dynamic statistics
        info.download_checksums = {
            "url1": "some checksum",
            "url2": "some other checksum",
        }

        with testing.tmp_dir(self.get_temp_dir()) as tmp_dir:
            # Save it
            info.write_to_directory(tmp_dir)

            # If fields are not defined, then everything is restored from disk
            restored_info = dataset_info.DatasetInfo(builder=self._builder)
            restored_info.read_from_directory(tmp_dir)
            self.assertProtoEquals(info.as_proto, restored_info.as_proto)

        with testing.tmp_dir(self.get_temp_dir()) as tmp_dir:
            # Save it
            info.write_to_directory(tmp_dir)

            # If fields are defined, then the code version is kept
            restored_info = dataset_info.DatasetInfo(
                builder=self._builder,
                supervised_keys=("input (new)", "output (new)"),
                homepage="http://some-location-new",
                citation="some citation (new)",
                redistribution_info={"license": "some license (new)"})
            restored_info.download_size = 789
            restored_info.as_proto.splits.add(name="validation", num_bytes=288)
            restored_info.as_proto.schema.feature.add()
            restored_info.as_proto.schema.feature.add()
            restored_info.as_proto.schema.feature.add()
            restored_info.as_proto.schema.feature.add(
            )  # Add dynamic statistics
            restored_info.download_checksums = {
                "url2": "some other checksum (new)",
                "url3": "some checksum (new)",
            }

            restored_info.read_from_directory(tmp_dir)

            # Even though restored_info has been restored, information defined
            # in the code overwrites information from the json file.
            self.assertEqual(restored_info.description, "A description")
            self.assertEqual(restored_info.version, utils.Version("3.0.1"))
            self.assertEqual(restored_info.release_notes, {})
            self.assertEqual(restored_info.supervised_keys,
                             ("input (new)", "output (new)"))
            self.assertEqual(restored_info.homepage,
                             "http://some-location-new")
            self.assertEqual(restored_info.citation, "some citation (new)")
            self.assertEqual(restored_info.redistribution_info.license,
                             "some license (new)")
            self.assertEqual(restored_info.download_size, 789)
            self.assertEqual(restored_info.dataset_size, 576)
            self.assertEqual(len(restored_info.as_proto.schema.feature), 4)
            self.assertEqual(
                restored_info.download_checksums, {
                    "url2": "some other checksum (new)",
                    "url3": "some checksum (new)",
                })