class DummyTFRecordBuilder(dataset_builder.GeneratorBasedBuilder):

  VERSION = utils.Version("0.0.0")

  def _split_generators(self, dl_manager):
    return [
        splits.SplitGenerator(
            name=splits.Split.TRAIN,
            num_shards=2,
            gen_kwargs={"range_": range(20)}),
        splits.SplitGenerator(
            name=splits.Split.VALIDATION,
            num_shards=1,
            gen_kwargs={"range_": range(20, 30)}),
        splits.SplitGenerator(
            name=splits.Split.TEST,
            num_shards=1,
            gen_kwargs={"range_": range(30, 40)}),
    ]

  def _generate_examples(self, range_):
    for i in range_:
      yield {
          "x": i,
          "y": np.array([-i]).astype(np.int64)[0],
          "z": tf.compat.as_text(str(i)),
      }

  def _info(self):
    return dataset_info.DatasetInfo(
        builder=self,
        features=features.FeaturesDict({
            "x": tf.int64,
            "y": tf.int64,
            "z": tf.string,
        }),
    )
class NestedSequenceBuilder(dataset_builder.GeneratorBasedBuilder):
  """Dataset containing nested sequences."""

  VERSION = utils.Version("0.0.1")

  def _info(self):
    return dataset_info.DatasetInfo(
        builder=self,
        features=features.FeaturesDict({
            "frames": features.Sequence({
                "coordinates": features.Sequence(
                    features.Tensor(shape=(2,), dtype=tf.int32)),
            }),
        }),
    )

  def _split_generators(self, dl_manager):
    del dl_manager
    return [
        splits_lib.SplitGenerator(
            name=splits_lib.Split.TRAIN,
            gen_kwargs={},
        ),
    ]

  def _generate_examples(self):
    ex0 = [[[0, 1], [2, 3], [4, 5]], [], [[6, 7]]]
    ex1 = []
    ex2 = [
        [[10, 11]],
        [[12, 13], [14, 15]],
    ]
    for i, ex in enumerate([ex0, ex1, ex2]):
      yield i, {"frames": {"coordinates": ex}}
class DummyMnist(dataset_builder.GeneratorBasedBuilder):
  """Test DatasetBuilder."""

  VERSION = utils.Version("1.0.0")

  def __init__(self, *args, **kwargs):
    self._num_shards = kwargs.pop("num_shards", 10)
    super(DummyMnist, self).__init__(*args, **kwargs)

  def _info(self):
    return dataset_info.DatasetInfo(
        builder=self,
        features=features.FeaturesDict({
            "image": features.Image(shape=(28, 28, 1)),
            "label": features.ClassLabel(num_classes=10),
        }),
    )

  def _split_generators(self, dl_manager):
    return [
        splits.SplitGenerator(
            name=splits.Split.TRAIN,
            num_shards=self._num_shards,
            gen_kwargs=dict()),
        splits.SplitGenerator(
            name=splits.Split.TEST,
            num_shards=1,
            gen_kwargs=dict()),
    ]

  def _generate_examples(self):
    for i in range(20):
      yield {
          "image": np.ones((28, 28, 1), dtype=np.uint8),
          "label": i % 10,
      }
class DummyDatasetSharedGenerator(dataset_builder.GeneratorBasedBuilder):
  """Test DatasetBuilder."""

  VERSION = utils.Version("1.0.0")

  def _info(self):
    return dataset_info.DatasetInfo(
        builder=self,
        features=features.FeaturesDict({"x": tf.int64}),
        supervised_keys=("x", "x"),
    )

  def _split_generators(self, dl_manager):
    # Split the 30 examples from the generator into 2 train shards and 1 test
    # shard.
    del dl_manager
    return [splits.SplitGenerator(
        name=[splits.Split.TRAIN, splits.Split.TEST],
        num_shards=[2, 1],
    )]

  def _generate_examples(self):
    for i in range(30):
      yield {"x": i}
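# Hedged usage sketch (not from the source): roughly how builders such as
# DummyDatasetSharedGenerator above are exercised in tests, assuming a TFDS
# version contemporary with these snippets. `tempfile` stands in for the TFDS
# testing helpers; the rest is the standard `DatasetBuilder` API
# (`download_and_prepare`, `as_dataset`).
import tempfile


def _example_dummy_builder_usage():
  tmp_dir = tempfile.mkdtemp()
  builder = DummyDatasetSharedGenerator(data_dir=tmp_dir)
  builder.download_and_prepare()  # Generates and writes the 30 examples.
  # Returns a tf.data.Dataset for the train split (2 shards, as declared above).
  return builder.as_dataset(split=splits.Split.TRAIN)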
def version(self):
  return utils.Version(self.as_proto.version)
def _is_version_valid(version):
  try:
    return utils.Version(version) and True
  except ValueError:  # Invalid version (e.g. incomplete data dir)
    return False
def versions(self):
  """Versions (canonical + available), in preference order."""
  return [
      utils.Version(v) if isinstance(v, six.string_types) else v
      for v in [self.canonical_version] + self.supported_versions
  ]
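# Hedged sketch (not from the source): a minimal illustration of the
# `utils.Version` behavior the helpers above rely on, assuming
# `tensorflow_datasets.core.utils` exposes `Version` as in these snippets.
from tensorflow_datasets.core import utils

v = utils.Version("1.2.0")
assert v > utils.Version("1.0.0")  # Versions are ordered.
assert str(v) == "1.2.0"
try:
  utils.Version("not-a-version")  # Malformed strings raise ValueError,
except ValueError:                # which _is_version_valid() turns into False.
  pass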
def _pick_version(self, version: str) -> utils.Version:
  return utils.Version(version)
class FaultyS3DummyBeamDataset(DummyBeamDataset):

  VERSION = utils.Version("1.0.0")
def canonical_version(self) -> utils.Version:
  if self._builder_config and self._builder_config.version:
    return utils.Version(self._builder_config.version)
  else:
    return utils.Version(self.VERSION)
class MultiDirDataset(DummyDatasetSharedGenerator):  # pylint: disable=unused-variable
  VERSION = utils.Version("1.2.0")
def write_metadata(
    *,
    data_dir: epath.PathLike,
    features: features_lib.feature.FeatureConnectorArg,
    split_infos: Union[None, epath.PathLike, List[split_lib.SplitInfo]] = None,
    version: Union[None, str, utils.Version] = None,
    check_data: bool = True,
    **ds_info_kwargs,
) -> None:
  """Adds the metadata required to load the dataset with TFDS.

  See documentation for usage:
  https://www.tensorflow.org/datasets/external_tfrecord

  Args:
    data_dir: Dataset path on which to save the metadata.
    features: dict of `tfds.features.FeatureConnector` matching the proto
      specs.
    split_infos: Can be either:
      * A path to the pre-computed split info values (the `out_dir` kwarg of
        `tfds.folder_dataset.compute_split_info`)
      * A list of `tfds.core.SplitInfo` (returned value of
        `tfds.folder_dataset.compute_split_info`)
      * `None` to auto-compute the split info.
    version: Optional dataset version (auto-inferred by default, or falls back
      to 1.0.0).
    check_data: If True, performs an additional check to validate that the data
      in `data_dir` is valid.
    **ds_info_kwargs: Additional metadata forwarded to `tfds.core.DatasetInfo`
      (description, homepage,...). Will appear in the doc.
  """
  features = features_lib.features_dict.to_feature(features)
  data_dir = epath.Path(data_dir)

  # Extract the tf-record filenames.
  tfrecord_files = [
      f for f in data_dir.iterdir() if naming.FilenameInfo.is_valid(f.name)
  ]
  if not tfrecord_files:
    raise ValueError(
        f'Could not find tf-record (or compatible format) in {data_dir}. '
        'Make sure to follow the pattern: '
        '`<dataset_name>-<split_name>.<file-extension>-xxxxxx-of-yyyyyy`')
  file_infos = [naming.FilenameInfo.from_str(f.name) for f in tfrecord_files]

  # Use a set with tuple unpacking to ensure all names are consistent.
  snake_name, = {f.dataset_name for f in file_infos}
  camel_name = naming.snake_to_camelcase(snake_name)
  filetype_suffix, = {f.filetype_suffix for f in file_infos}
  file_format = file_adapters.file_format_from_suffix(filetype_suffix)

  cls = types.new_class(
      camel_name,
      bases=(_WriteBuilder,),
      kwds=dict(skip_registration=True),
      exec_body=None,
  )

  if version is None:  # Automatically detect the version.
    if utils.Version.is_valid(data_dir.name):
      version = data_dir.name
    else:
      version = '1.0.0'
  cls.VERSION = utils.Version(version)

  # Create a dummy builder (use a non-existent folder to make sure
  # dataset_info.json is not restored).
  builder = cls(file_format=file_format, data_dir='/tmp/non-existent-dir/')

  # Create the metadata.
  ds_info = dataset_info.DatasetInfo(
      builder=builder,
      features=features,
      **ds_info_kwargs,
  )
  ds_info.set_file_format(file_format)

  # Add the split infos.
  split_dict = _load_splits(
      data_dir=data_dir,
      split_infos=split_infos,
      file_infos=file_infos,
      filetype_suffix=filetype_suffix,
      builder=builder,
  )
  ds_info.set_splits(split_dict)

  # Save all metadata (dataset_info.json, features.json,...).
  ds_info.write_to_directory(data_dir)

  # Make sure that the data can be loaded (feature connectors match the actual
  # specs).
  if check_data:
    utils.print_notebook(
        'Metadata written. Testing by reading first example. '
        'Set check_data=False to skip.')
    builder = read_only_builder.builder_from_directory(data_dir)
    split_name = next(iter(builder.info.splits))
    _, = builder.as_dataset(split=f'{split_name}[:1]')  # Load the first example.
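# Hedged usage sketch (adapted from the external_tfrecord guide referenced in
# the docstring, not from this source file): the directory path, feature spec
# and description below are hypothetical placeholders.
import tensorflow_datasets as tfds

tfds.folder_dataset.write_metadata(
    data_dir='/path/to/my_dataset/1.0.0',  # Hypothetical directory of tf-records.
    features=tfds.features.FeaturesDict({
        'image': tfds.features.Image(shape=(None, None, 3)),
        'label': tfds.features.ClassLabel(names=['no', 'yes']),
    }),
    split_infos=None,  # Let write_metadata auto-compute the split info.
    description='Externally generated tf-record dataset.',  # Forwarded to DatasetInfo.
)

# Once the metadata is written, the data can be read back like any TFDS dataset.
builder = tfds.builder_from_directory('/path/to/my_dataset/1.0.0')
ds = builder.as_dataset(split='train')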
class DummyNoConfMnist(testing.DummyDataset):
  """Same as DummyMnist (but declared here to avoid skip_registering issues)."""

  VERSION = utils.Version('0.1.0')
def test_restore_after_modification(self):
  # Create a DatasetInfo.
  info = dataset_info.DatasetInfo(
      builder=self._builder,
      description="A description",
      supervised_keys=("input", "output"),
      homepage="http://some-location",
      citation="some citation",
      license="some license",
  )
  info.download_size = 456
  filepath_template = "{DATASET}-{SPLIT}.{FILEFORMAT}-{SHARD_X_OF_Y}"
  info.as_proto.splits.add(
      name="train", num_bytes=512, filepath_template=filepath_template)
  info.as_proto.splits.add(
      name="validation", num_bytes=64, filepath_template=filepath_template)
  info.as_proto.schema.feature.add()
  info.as_proto.schema.feature.add()
  # Add dynamic statistics.
  info.download_checksums = {
      "url1": "some checksum",
      "url2": "some other checksum",
  }

  with testing.tmp_dir(self.get_temp_dir()) as tmp_dir:
    # Save it.
    info.write_to_directory(tmp_dir)

    # If fields are not defined, then everything is restored from disk.
    restored_info = dataset_info.DatasetInfo(builder=self._builder)
    restored_info.read_from_directory(tmp_dir)
    self.assertProtoEquals(info.as_proto, restored_info.as_proto)

  with testing.tmp_dir(self.get_temp_dir()) as tmp_dir:
    # Save it.
    info.write_to_directory(tmp_dir)

    # If fields are defined, then the code version is kept.
    restored_info = dataset_info.DatasetInfo(
        builder=self._builder,
        supervised_keys=("input (new)", "output (new)"),
        homepage="http://some-location-new",
        citation="some citation (new)",
        redistribution_info={"license": "some license (new)"})
    restored_info.download_size = 789
    restored_info.as_proto.splits.add(name="validation", num_bytes=288)
    restored_info.as_proto.schema.feature.add()
    restored_info.as_proto.schema.feature.add()
    restored_info.as_proto.schema.feature.add()
    restored_info.as_proto.schema.feature.add()
    # Add dynamic statistics.
    restored_info.download_checksums = {
        "url2": "some other checksum (new)",
        "url3": "some checksum (new)",
    }

    restored_info.read_from_directory(tmp_dir)

    # Even though restored_info has been restored, information defined in
    # the code overwrites information from the JSON file.
    self.assertEqual(restored_info.description, "A description")
    self.assertEqual(restored_info.version, utils.Version("3.0.1"))
    self.assertEqual(restored_info.release_notes, {})
    self.assertEqual(restored_info.supervised_keys,
                     ("input (new)", "output (new)"))
    self.assertEqual(restored_info.homepage, "http://some-location-new")
    self.assertEqual(restored_info.citation, "some citation (new)")
    self.assertEqual(restored_info.redistribution_info.license,
                     "some license (new)")
    self.assertEqual(restored_info.download_size, 789)
    self.assertEqual(restored_info.dataset_size, 576)
    self.assertEqual(len(restored_info.as_proto.schema.feature), 4)
    self.assertEqual(
        restored_info.download_checksums, {
            "url2": "some other checksum (new)",
            "url3": "some checksum (new)",
        })