def test_read_config(self):
  is_called = []
  def interleave_sort(lists):
    is_called.append(True)
    return lists

  with testing.tmp_dir(self.get_temp_dir()) as tmp_dir:
    read_config = read_config_lib.ReadConfig(
        experimental_interleave_sort_fn=interleave_sort,
    )
    read_config.options.experimental_stats.prefix = "tfds_prefix"
    ds = registered.load(
        name="dummy_dataset_shared_generator",
        data_dir=tmp_dir,
        split="train",
        read_config=read_config,
        shuffle_files=True,
    )

    # Check that the ReadConfig options are properly set
    self.assertEqual(ds.options().experimental_stats.prefix, "tfds_prefix")

    # The interleave_sort function should have been called
    self.assertEqual(is_called, [True])
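For context, the user-facing pattern this test exercises looks roughly like the sketch below. This is a minimal sketch, not part of the test suite: the `tfds` alias and the `'mnist'` dataset name are illustrative assumptions, and `reverse_shards` is a hypothetical sort function.

```python
import tensorflow_datasets as tfds

# Hypothetical sort function: receives the list of per-shard read
# instructions and returns them in the order to interleave (here, reversed).
def reverse_shards(shards):
  return list(reversed(shards))

read_config = tfds.ReadConfig(
    experimental_interleave_sort_fn=reverse_shards,
)
ds = tfds.load(
    "mnist",  # illustrative dataset name
    split="train",
    read_config=read_config,
    shuffle_files=True,
)
```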
def as_dataset(
    self,
    split=None,
    batch_size=None,
    shuffle_files=False,
    decoders=None,
    read_config=None,
    as_supervised=False,
):
  # pylint: disable=line-too-long
  """Constructs a `tf.data.Dataset`.

  Callers must pass arguments as keyword arguments.

  The output types vary depending on the parameters. Examples:

  ```python
  builder = tfds.builder('imdb_reviews')
  builder.download_and_prepare()

  # Default parameters: Returns the dict of tf.data.Dataset
  ds_all_dict = builder.as_dataset()
  assert isinstance(ds_all_dict, dict)
  print(ds_all_dict.keys())  # ==> ['test', 'train', 'unsupervised']

  assert isinstance(ds_all_dict['test'], tf.data.Dataset)
  # Each dataset (test, train, unsup.) consists of dictionaries
  # {'label': <tf.Tensor: .. dtype=int64, numpy=1>,
  #  'text': <tf.Tensor: .. dtype=string, numpy=b"I've watched the movie ..">}
  # {'label': <tf.Tensor: .. dtype=int64, numpy=1>,
  #  'text': <tf.Tensor: .. dtype=string, numpy=b'If you love Japanese ..'>}

  # With as_supervised: tf.data.Dataset only contains (feature, label) tuples
  ds_all_supervised = builder.as_dataset(as_supervised=True)
  assert isinstance(ds_all_supervised, dict)
  print(ds_all_supervised.keys())  # ==> ['test', 'train', 'unsupervised']

  assert isinstance(ds_all_supervised['test'], tf.data.Dataset)
  # Each dataset (test, train, unsup.) consists of tuples (text, label)
  # (<tf.Tensor: ... dtype=string, numpy=b"I've watched the movie ..">,
  #  <tf.Tensor: ... dtype=int64, numpy=1>)
  # (<tf.Tensor: ... dtype=string, numpy=b"If you love Japanese ..">,
  #  <tf.Tensor: ... dtype=int64, numpy=1>)

  # Same as above plus requesting a particular split
  ds_test_supervised = builder.as_dataset(as_supervised=True, split='test')
  assert isinstance(ds_test_supervised, tf.data.Dataset)
  # The dataset consists of tuples (text, label)
  # (<tf.Tensor: ... dtype=string, numpy=b"I've watched the movie ..">,
  #  <tf.Tensor: ... dtype=int64, numpy=1>)
  # (<tf.Tensor: ... dtype=string, numpy=b"If you love Japanese ..">,
  #  <tf.Tensor: ... dtype=int64, numpy=1>)
  ```

  Args:
    split: Which split of the data to load (e.g. `'train'`, `'test'`,
      `['train', 'test']`, `'train[80%:]'`,...). See our
      [split API guide](https://www.tensorflow.org/datasets/splits). If
      `None`, will return all splits in a `Dict[Split, tf.data.Dataset]`.
    batch_size: `int`, batch size. Note that variable-length features will be
      0-padded if `batch_size` is set. Users that want more custom behavior
      should use `batch_size=None` and use the `tf.data` API to construct a
      custom pipeline. If `batch_size == -1`, will return feature
      dictionaries of the whole dataset with `tf.Tensor`s instead of a
      `tf.data.Dataset`.
    shuffle_files: `bool`, whether to shuffle the input files. Defaults to
      `False`.
    decoders: Nested dict of `Decoder` objects which allow customizing the
      decoding. The structure should match the feature structure, but only
      customized feature keys need to be present. See
      [the guide](https://github.com/tensorflow/datasets/tree/master/docs/decode.md)
      for more info.
    read_config: `tfds.ReadConfig`, additional options to configure the input
      pipeline (e.g. seed, num parallel reads,...).
    as_supervised: `bool`, if `True`, the returned `tf.data.Dataset` will
      have a 2-tuple structure `(input, label)` according to
      `builder.info.supervised_keys`. If `False`, the default, the returned
      `tf.data.Dataset` will have a dictionary with all the features.

  Returns:
    `tf.data.Dataset`, or if `split=None`, `dict<key: tfds.Split, value:
    tf.data.Dataset>`.
    If `batch_size` is -1, will return feature dictionaries containing the
    entire dataset in `tf.Tensor`s instead of a `tf.data.Dataset`.
  """
  # pylint: enable=line-too-long
  logging.info("Constructing tf.data.Dataset for split %s, from %s",
               split, self._data_dir)
  if not tf.io.gfile.exists(self._data_dir):
    raise AssertionError(
        ("Dataset %s: could not find data in %s. Please make sure to call "
         "dataset_builder.download_and_prepare(), or pass download=True to "
         "tfds.load() before trying to access the tf.data.Dataset object."
        ) % (self.name, self._data_dir_root))

  # By default, return all splits
  if split is None:
    split = {s: s for s in self.info.splits}

  read_config = read_config or read_config_lib.ReadConfig()

  # Create a dataset for each of the given splits
  build_single_dataset = functools.partial(
      self._build_single_dataset,
      shuffle_files=shuffle_files,
      batch_size=batch_size,
      decoders=decoders,
      read_config=read_config,
      as_supervised=as_supervised,
  )
  datasets = utils.map_nested(build_single_dataset, split, map_tuple=True)
  return datasets
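The fan-out over splits in the last step behaves roughly as sketched below. This is a labeled simplification, not the actual `utils.map_nested` implementation, which handles more cases:

```python
# Simplified sketch (assumption): how a nested-map utility turns a split
# spec into datasets. A dict maps value-wise; a plain value maps directly.
def map_nested_sketch(fn, data, map_tuple=True):
  if isinstance(data, dict):
    return {k: map_nested_sketch(fn, v, map_tuple) for k, v in data.items()}
  if isinstance(data, list) or (map_tuple and isinstance(data, tuple)):
    return type(data)(map_nested_sketch(fn, v, map_tuple) for v in data)
  return fn(data)

# With split=None, `split` was replaced by {'train': 'train', ...}, so the
# caller receives a dict of tf.data.Dataset keyed by split name; with
# split='train', build_single_dataset is applied directly to the string.
```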