def make_voc_dataset_sub(n_train, n_val, batch_size, eval_batch_size, transform,
                         data_dir=None, prefetch=True):
    steps_per_epoch, val_steps = n_train // batch_size, n_val // eval_batch_size
    ds_train = tfds.load("voc/2012", split=f"train[:{n_train}]", data_dir=data_dir,
                         shuffle_files=True,
                         read_config=tfds.ReadConfig(try_autocache=False,
                                                     skip_prefetch=True))
    ds_val = tfds.load("voc/2012", split=f"train[:{n_val}]", data_dir=data_dir,
                       shuffle_files=False,
                       read_config=tfds.ReadConfig(try_autocache=False,
                                                   skip_prefetch=True))
    ds_train = prepare(ds_train, batch_size, transform(training=True),
                       training=True, repeat=False, prefetch=prefetch)
    ds_val = prepare(ds_val, eval_batch_size, transform(training=False),
                     training=False, repeat=False, drop_remainder=False,
                     prefetch=prefetch)
    return ds_train, ds_val, steps_per_epoch, val_steps
def _get_tfds_dataset(
        dataset: str,
        rng: np.ndarray) -> Tuple[tf.data.Dataset, tf.data.Dataset, int]:
    """Loads a TFDS dataset."""
    dataset_builder = tfds.builder(dataset)
    num_classes = 0
    if "label" in dataset_builder.info.features:
        num_classes = dataset_builder.info.features["label"].num_classes

    # Make sure each host uses a different RNG for the training data.
    rng, data_rng = jax.random.split(rng)
    data_rng = jax.random.fold_in(data_rng, jax.host_id())
    data_rng, shuffle_rng = jax.random.split(data_rng)

    train_split = deterministic_data.get_read_instruction_for_host(
        "train", dataset_builder.info.splits["train"].num_examples)
    train_read_config = tfds.ReadConfig(shuffle_seed=shuffle_rng[0])
    train_ds = dataset_builder.as_dataset(split=train_split,
                                          shuffle_files=True,
                                          read_config=train_read_config)

    eval_split_name = {
        "cifar10": "test",
        "imagenet2012": "validation"
    }.get(dataset, "test")
    eval_split_size = dataset_builder.info.splits[eval_split_name].num_examples
    eval_split = deterministic_data.get_read_instruction_for_host(
        eval_split_name, eval_split_size)
    eval_read_config = tfds.ReadConfig(shuffle_seed=shuffle_rng[1])
    eval_ds = dataset_builder.as_dataset(split=eval_split,
                                         shuffle_files=False,
                                         read_config=eval_read_config)
    return train_ds, eval_ds, num_classes
def _read_tfds(tfds_builder: tfds.core.DatasetBuilder,
               tfds_split: Text,
               tfds_skip_decoding_feature: Text,
               tfds_as_supervised: bool,
               input_context: Optional[tf.distribute.InputContext] = None,
               seed: Optional[Union[int, tf.Tensor]] = None,
               is_training: bool = False,
               cache: bool = False,
               cycle_length: Optional[int] = None,
               block_length: Optional[int] = None) -> tf.data.Dataset:
    """Reads a dataset from tfds."""
    # No-op if the data already exists.
    tfds_builder.download_and_prepare()

    decoders = {}
    if tfds_skip_decoding_feature:
        for skip_feature in tfds_skip_decoding_feature.split(','):
            decoders[skip_feature.strip()] = tfds.decode.SkipDecoding()

    if tfds_builder.info.splits:
        num_shards = len(tfds_builder.info.splits[tfds_split].file_instructions)
    else:
        # The tfds mock path often does not provide splits.
        num_shards = 1

    if input_context and num_shards < input_context.num_input_pipelines:
        # The number of files in the dataset split is smaller than the number
        # of input pipelines. Read the entire dataset first, then shard it in
        # host memory.
        read_config = tfds.ReadConfig(interleave_cycle_length=cycle_length,
                                      interleave_block_length=block_length,
                                      input_context=None,
                                      shuffle_seed=seed)
        dataset = tfds_builder.as_dataset(split=tfds_split,
                                          shuffle_files=is_training,
                                          as_supervised=tfds_as_supervised,
                                          decoders=decoders,
                                          read_config=read_config)
        dataset = dataset.shard(input_context.num_input_pipelines,
                                input_context.input_pipeline_id)
    else:
        read_config = tfds.ReadConfig(interleave_cycle_length=cycle_length,
                                      interleave_block_length=block_length,
                                      input_context=input_context,
                                      shuffle_seed=seed)
        dataset = tfds_builder.as_dataset(split=tfds_split,
                                          shuffle_files=is_training,
                                          as_supervised=tfds_as_supervised,
                                          decoders=decoders,
                                          read_config=read_config)

    if is_training and not cache:
        dataset = dataset.repeat()
    return dataset
def __init__(self):
    import tensorflow_datasets as tfds
    train = tfds.load(
        'imdb_reviews',
        split='train',
        read_config=tfds.ReadConfig(shuffle_seed=seed,
                                    shuffle_reshuffle_each_iteration=True),
    )
    test = tfds.load(
        'imdb_reviews',
        split='test',
        read_config=tfds.ReadConfig(shuffle_seed=seed,
                                    shuffle_reshuffle_each_iteration=True),
    )
    print(train)
def _load_dataset(split, should_shuffle, data_rng, data_dir):
    """Loads a dataset split from TFDS."""
    if should_shuffle:
        file_data_rng, dataset_data_rng = jax.random.split(data_rng)
        file_data_rng = file_data_rng[0]
        dataset_data_rng = dataset_data_rng[0]
    else:
        file_data_rng = None
        dataset_data_rng = None

    read_config = tfds.ReadConfig(add_tfds_id=True, shuffle_seed=file_data_rng)
    dataset = tfds.load('ogbg_molpcba',
                        split='train' if split == 'eval_train' else split,
                        shuffle_files=should_shuffle,
                        read_config=read_config,
                        data_dir=data_dir)

    if should_shuffle:
        dataset = dataset.shuffle(seed=dataset_data_rng, buffer_size=2**15)
        dataset = dataset.repeat()

    # We do not need to repeat the dataset for evaluation, because we call
    # itertools.cycle on the eval iterator, which keeps the yielded elements in
    # memory so they can be cycled through again.
    return dataset
def _merge_langs_dataset_fn(split, shuffle_files, tfds_name, langs, seed):
    """Creates a single dataset containing all languages in a tfds dataset.

    Loads individual language datasets for the specified languages and
    concatenates them into a single dataset. This is done so that we can have a
    single test task containing all the languages, in order to compute a single
    overall performance metric across all languages.

    Args:
        split: string, which split to use.
        shuffle_files: boolean, whether to shuffle the input files.
        tfds_name: string, name of the TFDS dataset to load.
        langs: list of strings, the languages to load.
        seed: shuffle seed passed to `tfds.ReadConfig`.

    Returns:
        A tf.data.Dataset.
    """
    ds = []
    for lang in langs:
        ds.append(
            tfds.load("{}/{}".format(tfds_name, lang),
                      split=split,
                      shuffle_files=shuffle_files,
                      read_config=tfds.ReadConfig(shuffle_seed=seed)))
    all_langs_data = ds[0]
    for lang_data in ds[1:]:
        all_langs_data = all_langs_data.concatenate(lang_data)
    return all_langs_data
def test_disable_ordering_guard(self):
    # Shuffling a sliced dataset should still yield the same examples.
    source = 'fungi'
    split = 'valid'
    label = 13
    offset = 994
    shuffle_seed = 1234
    tfds_dataset = tfds.builder(
        'meta_dataset', config=source, data_dir=FLAGS.tfds_path
    ).as_class_dataset(
        md_version='v1',
        meta_split=split,
        relative_label=label,
        decoders={'image': tfds.decode.SkipDecoding()},
        batch_size=-1,
        shuffle_files=True,
        read_config=tfds.ReadConfig(
            shuffle_seed=shuffle_seed, enable_ordering_guard=False))
    tfds_images = tuple(tfds.as_numpy(tfds_dataset)['image'])
    md_dataset = tf.data.TFRecordDataset(
        f'{FLAGS.meta_dataset_path}/{source}/{label + offset}.tfrecords'
    ).map(test_utils.parse_example).batch(10_000, drop_remainder=False)
    md_data, = tfds.as_numpy(md_dataset)
    md_images = tuple(md_data['image'])
    # The order is not the same (`shuffle_files=True`)...
    self.assertNotEqual(tfds_images, md_images)
    # ... but the contents are the same.
    self.assertSameElements(tfds_images, md_images)
def load_tfds(self):
    logger.info('Using TFDS to load data.')
    set_hard_limit_num_open_files()

    self._builder = tfds.builder(self._dataset_name, data_dir=self._dataset_dir)
    self._builder.download_and_prepare()

    decoders = {}
    if self._skip_decoding:
        decoders['image'] = tfds.decode.SkipDecoding()

    read_config = tfds.ReadConfig(interleave_cycle_length=64,
                                  interleave_block_length=1)
    dataset = self._builder.as_dataset(split=self._split,
                                       as_supervised=True,
                                       shuffle_files=True,
                                       decoders=decoders,
                                       read_config=read_config)
    return dataset
def __init__(self,
             version: Literal[10, 20, 100],
             quantize_bits: int = 8,
             seed: int = 1):
    assert version in (10, 20, 100), \
        "Only support CIFAR-10, CIFAR-20 and CIFAR-100"
    self.version = version
    if version == 10:
        dsname = 'cifar10'
    else:
        dsname = 'cifar100'
    self.train, self.valid, self.test = tfds.load(
        name=dsname,
        split=['train[:48000]', 'train[48000:]', 'test'],
        # as_supervised=True,
        read_config=tfds.ReadConfig(shuffle_seed=seed,
                                    shuffle_reshuffle_each_iteration=True),
        shuffle_files=True,
        with_info=False,
    )
    if version in (10, 100):
        process = lambda dat: (dat['image'], dat['label'])
    elif version == 20:
        process = lambda dat: (dat['image'], dat['coarse_label'])
    self.train = self.train.map(process)
    self.valid = self.valid.map(process)
    self.test = self.test.map(process)
def _load_dataset(split,
                  should_shuffle=False,
                  shuffle_seed=None,
                  shuffle_buffer_size=None):
    """Loads a dataset split from TFDS."""
    if should_shuffle:
        assert shuffle_seed is not None and shuffle_buffer_size is not None
        file_shuffle_seed, dataset_shuffle_seed = jax.random.split(shuffle_seed)
        file_shuffle_seed = file_shuffle_seed[0]
        dataset_shuffle_seed = dataset_shuffle_seed[0]
    else:
        file_shuffle_seed = None
        dataset_shuffle_seed = None

    read_config = tfds.ReadConfig(add_tfds_id=True,
                                  shuffle_seed=file_shuffle_seed)
    dataset = tfds.load('ogbg_molpcba',
                        split=split,
                        shuffle_files=should_shuffle,
                        read_config=read_config)

    if should_shuffle:
        dataset = dataset.shuffle(seed=dataset_shuffle_seed,
                                  buffer_size=shuffle_buffer_size)
        dataset = dataset.repeat()
    return dataset
def large_imagenet(data_dir=None):
    """ImageNet dataset as a `tf.data.Dataset` pipeline.

    Note: JPEG decoding is skipped while loading this dataset, so binary strings
    representing each image are returned instead of decoded arrays. This improves
    pipeline performance when additional transformations such as random crops are
    used for training.

    Dataset homepage: http://www.image-net.org/

    Args:
        data_dir: Directory to read/write TFDS data from.
    """
    ds, info = tfds.load('imagenet2012',
                         shuffle_files=True,
                         read_config=tfds.ReadConfig(skip_prefetch=True),
                         decoders={'image': tfds.decode.SkipDecoding()},
                         as_supervised=True,
                         with_info=True,
                         data_dir=data_dir)
    ds_dict = {
        "train_ds": ds['train'],
        "val_ds": ds['validation'],
        "info": info,
    }
    if 'minival' in info.splits:
        ds_dict['minival_ds'] = ds['minival']
    return ds_dict
def init_data():
    """Initialize data."""
    (ds_train, ds_test), ds_info = tfds.load(
        'mnist',
        split=['train', 'test'],
        shuffle_files=True,
        with_info=True,
        as_supervised=True,
        read_config=tfds.ReadConfig(shuffle_seed=parse_args.seed,
                                    try_autocache=False))

    num_train = ds_info.splits['train'].num_examples
    num_test = ds_info.splits['test'].num_examples

    # make sure all batches are the same size to minimize jit compilation cache
    assert num_train % parse_args.batch_size == 0
    num_batches = num_train // parse_args.batch_size
    assert num_test % parse_args.test_batch_size == 0
    num_test_batches = num_test // parse_args.test_batch_size

    # make sure we always save the model on the last iteration
    assert num_batches * parse_args.nepochs % parse_args.save_freq == 0

    ds_train = ds_train.cache().repeat().shuffle(
        1000, seed=seed).batch(parse_args.batch_size)
    ds_test_eval = ds_test.batch(parse_args.test_batch_size).repeat()
    ds_train, ds_test_eval = tfds.as_numpy(ds_train), tfds.as_numpy(ds_test_eval)

    meta = {
        "num_batches": num_batches,
        "num_test_batches": num_test_batches
    }
    return ds_train, ds_test_eval, meta
def load_tfds(self) -> tf.data.Dataset:
    """Return a dataset loading files from TFDS."""
    logging.info('Using TFDS to load data.')

    builder = tfds.builder(self.config.name, data_dir=self.config.data_dir)
    if self.config.download:
        builder.download_and_prepare()

    decoders = {}
    if self.config.skip_decoding:
        decoders['image'] = tfds.decode.SkipDecoding()

    read_config = tfds.ReadConfig(interleave_cycle_length=10,
                                  interleave_block_length=1,
                                  input_context=self.input_context)
    dataset = builder.as_dataset(split=self.config.split,
                                 as_supervised=True,
                                 shuffle_files=True,
                                 decoders=decoders,
                                 read_config=read_config)
    return dataset
def load_shard(self, file_instruction):
    """Returns a dataset for a single shard of the TFDS TFRecord files."""
    ds = self.builder._tfrecords_reader.read_files(  # pylint:disable=protected-access
        [file_instruction],
        read_config=tfds.ReadConfig(),
        shuffle_files=False)
    return ds
def make_dataset(batch_size, eval_batch_size, transform, data_dir=None,
                 drop_remainder=None):
    n_train, n_val = NUM_EXAMPLES['train'], NUM_EXAMPLES['validation']
    steps_per_epoch = n_train // batch_size
    if drop_remainder:
        val_steps = n_val // eval_batch_size
    else:
        val_steps = math.ceil(n_val / eval_batch_size)

    read_config = tfds.ReadConfig(try_autocache=False, skip_prefetch=True)
    ds_train = tfds.load("coco/2017", split="train", data_dir=data_dir,
                         shuffle_files=True, read_config=read_config)
    ds_val = tfds.load("coco/2017", split="validation", data_dir=data_dir,
                       shuffle_files=False, read_config=read_config)

    ds_train = prepare(ds_train, batch_size, transform(training=True),
                       training=True, repeat=False)
    ds_val = prepare(ds_val, eval_batch_size, transform(training=False),
                     training=False, repeat=False,
                     drop_remainder=drop_remainder)
    return ds_train, ds_val, steps_per_epoch, val_steps
def _load_tfds(self, *, split, shuffle_seed):
    return tfds.load(
        'cifar10',
        split={'train': 'train', 'eval': 'test'}[split],
        shuffle_files=shuffle_seed is not None,
        read_config=None if shuffle_seed is None else tfds.ReadConfig(
            shuffle_seed=shuffle_seed))
def _read_tfds(tfds_builder: tfds.core.DatasetBuilder,
               tfds_split: Text,
               tfds_skip_decoding_feature: Text,
               tfds_as_supervised: bool,
               input_context: Optional[tf.distribute.InputContext] = None,
               seed: Optional[Union[int, tf.Tensor]] = None,
               is_training: bool = False,
               cache: bool = False,
               cycle_length: Optional[int] = None,
               block_length: Optional[int] = None) -> tf.data.Dataset:
    """Reads a dataset from tfds."""
    # No-op if the data already exists.
    tfds_builder.download_and_prepare()

    read_config = tfds.ReadConfig(interleave_cycle_length=cycle_length,
                                  interleave_block_length=block_length,
                                  input_context=input_context,
                                  shuffle_seed=seed)
    decoders = {}
    if tfds_skip_decoding_feature:
        for skip_feature in tfds_skip_decoding_feature.split(','):
            decoders[skip_feature.strip()] = tfds.decode.SkipDecoding()
    dataset = tfds_builder.as_dataset(split=tfds_split,
                                      shuffle_files=is_training,
                                      as_supervised=tfds_as_supervised,
                                      decoders=decoders,
                                      read_config=read_config)

    if is_training and not cache:
        dataset = dataset.repeat()
    return dataset
def _read_tfds(
        self,
        input_context: Optional[tf.distribute.InputContext] = None
) -> tf.data.Dataset:
    """Reads a dataset from tfds."""
    # No-op if the data already exists.
    self._tfds_builder.download_and_prepare()

    read_config = tfds.ReadConfig(
        interleave_cycle_length=self._cycle_length,
        interleave_block_length=self._block_length,
        input_context=input_context,
        shuffle_seed=self._seed)
    decoders = {}
    if self._tfds_skip_decoding_feature:
        for skip_feature in self._tfds_skip_decoding_feature.split(','):
            decoders[skip_feature.strip()] = tfds.decode.SkipDecoding()
    dataset = self._tfds_builder.as_dataset(
        split=self._tfds_split,
        shuffle_files=self._is_training,
        as_supervised=self._tfds_as_supervised,
        decoders=decoders,
        read_config=read_config)

    # If cache is enabled, `repeat()` is called later, after `cache()`.
    if self._is_training and not self._cache:
        dataset = dataset.repeat()
    return dataset
def __init__(self, image_size: Optional[int] = 28, seed: int = 1):
    train, valid, test = tfds.load(
        name='omniglot',
        split=['train[:90%]', 'train[90%:]', 'test'],
        read_config=tfds.ReadConfig(shuffle_seed=seed,
                                    shuffle_reshuffle_each_iteration=True),
        as_supervised=True,
    )
    if image_size is None:
        image_size = 105
    image_size = int(image_size)
    if image_size != 105:

        @tf.function
        def resize(x, y):
            x = tf.image.resize(x,
                                size=(image_size, image_size),
                                method=tf.image.ResizeMethod.BILINEAR,
                                preserve_aspect_ratio=True,
                                antialias=True)
            y = tf.cast(y, dtype=tf.float32)
            return x, y

        train = train.map(resize, tf.data.AUTOTUNE)
        valid = valid.map(resize, tf.data.AUTOTUNE)
        test = test.map(resize, tf.data.AUTOTUNE)

    self.train = train
    self.valid = valid
    self.test = test
    self._image_size = image_size
def _read_tfds(
        self,
        input_context: Optional[tf.distribute.InputContext] = None
) -> tf.data.Dataset:
    """Reads a dataset from tfds."""
    if self._tfds_download:
        self._tfds_builder.download_and_prepare()

    read_config = tfds.ReadConfig(
        interleave_cycle_length=self._cycle_length,
        interleave_block_length=self._block_length,
        input_context=input_context)
    decoders = {}
    if self._tfds_skip_decoding_feature:
        for skip_feature in self._tfds_skip_decoding_feature.split(','):
            decoders[skip_feature.strip()] = tfds.decode.SkipDecoding()
    dataset = self._tfds_builder.as_dataset(
        split=self._tfds_split,
        shuffle_files=self._is_training,
        as_supervised=self._tfds_as_supervised,
        decoders=decoders,
        read_config=read_config)

    if self._is_training:
        dataset = dataset.repeat()
    return dataset
def __init__(self):
    import tensorflow_datasets as tfds
    train_examples, val_examples = tfds.load(
        'math_dataset/arithmetic__mul',
        split=['train', 'test'],
        read_config=tfds.ReadConfig(shuffle_seed=seed,
                                    shuffle_reshuffle_each_iteration=True),
        as_supervised=True)
def load(name: str, *, split, data_dir=None, shuffle_files: bool = False):
    return tfds.load(name,
                     split=split,
                     data_dir=data_dir,
                     download=False,
                     shuffle_files=shuffle_files,
                     read_config=tfds.ReadConfig(try_autocache=False,
                                                 skip_prefetch=True))
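# A minimal usage sketch for the `load` helper above; the dataset name, batch
# size, and downstream pipeline are illustrative and not taken from the
# original code. Because the helper passes `download=False`, the dataset is
# assumed to have already been prepared under `data_dir`, and because it sets
# `try_autocache=False` and `skip_prefetch=True`, caching and prefetching are
# left to the caller's own pipeline (assumes `tensorflow` is imported as `tf`).
ds_train = load("mnist", split="train", shuffle_files=True)
ds_train = ds_train.shuffle(10_000).batch(128).prefetch(tf.data.AUTOTUNE)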
def _load_tfds(self, *, split, shuffle_seed):
    return tfds.load(
        'imagenet2012',
        split={'train': 'train', 'eval': 'validation'}[split],
        shuffle_files=shuffle_seed is not None,
        read_config=None if shuffle_seed is None else tfds.ReadConfig(
            shuffle_seed=shuffle_seed),
        decoders={'image': tfds.decode.SkipDecoding()})
def load_dataset(split: str, batch_size: int) -> Iterator[Batch]:
    ds = tfds.load("binarized_mnist",
                   split=split,
                   shuffle_files=True,
                   read_config=tfds.ReadConfig(shuffle_seed=FLAGS.random_seed))
    ds = ds.shuffle(buffer_size=10 * batch_size, seed=FLAGS.random_seed)
    ds = ds.batch(batch_size)
    ds = ds.prefetch(buffer_size=5)
    ds = ds.repeat()
    return iter(tfds.as_numpy(ds))
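# A minimal sketch of consuming the iterator returned by `load_dataset` above;
# the split name, batch size, and step count are illustrative. Each element is
# a NumPy batch (via `tfds.as_numpy`), and since the pipeline calls `repeat()`,
# the iterator never ends: the caller decides how many steps to draw.
train_batches = load_dataset("train", batch_size=128)
for _ in range(1000):
    batch = next(train_batches)  # e.g. a dict with an 'image' array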
def __init__(self):
    super(MNIST, self).__init__()
    self.train, self.valid, self.test = tfds.load(
        name='mnist',
        split=['train[:55000]', 'train[55000:]', 'test'],
        read_config=tfds.ReadConfig(shuffle_seed=1,
                                    shuffle_reshuffle_each_iteration=True),
        as_supervised=True)
def _build_autocached_info(self, builder: tfds.core.DatasetBuilder):
    """Returns the auto-cache information string."""
    always_cached = {}
    never_cached = {}
    unshuffle_cached = {}
    for split_name in sorted(builder.info.splits.keys()):
        split_name = str(split_name)
        cache_shuffled = builder._should_cache_ds(  # pylint: disable=protected-access
            split_name, shuffle_files=True, read_config=tfds.ReadConfig())
        cache_unshuffled = builder._should_cache_ds(  # pylint: disable=protected-access
            split_name, shuffle_files=False, read_config=tfds.ReadConfig())

        if all((cache_shuffled, cache_unshuffled)):
            always_cached[split_name] = None
        elif not any((cache_shuffled, cache_unshuffled)):
            never_cached[split_name] = None
        else:
            # The dataset is only cached when shuffle_files is False.
            assert not cache_shuffled and cache_unshuffled
            unshuffle_cached[split_name] = None

    if not len(builder.info.splits) or not builder.info.dataset_size:  # pylint: disable=g-explicit-length-test
        autocached_info = 'Unknown'
    elif len(always_cached) == len(builder.info.splits.keys()):
        autocached_info = 'Yes'  # All splits are auto-cached.
    elif len(never_cached) == len(builder.info.splits.keys()):
        autocached_info = 'No'  # No split is auto-cached.
    else:  # Some splits are cached, some are not.
        autocached_info_parts = []
        if always_cached:
            split_names_str = ', '.join(always_cached)
            autocached_info_parts.append('Yes ({})'.format(split_names_str))
        if never_cached:
            split_names_str = ', '.join(never_cached)
            autocached_info_parts.append('No ({})'.format(split_names_str))
        if unshuffle_cached:
            split_names_str = ', '.join(unshuffle_cached)
            autocached_info_parts.append(
                'Only when `shuffle_files=False` ({})'.format(split_names_str))
        autocached_info = ', '.join(autocached_info_parts)
    return autocached_info
def test_mocking_add_tfds_id():
    read_config = tfds.ReadConfig(add_tfds_id=True)
    ds = tfds.load('mnist', split='train', read_config=read_config)
    assert ds.element_spec == {
        'tfds_id': tf.TensorSpec(shape=(), dtype=tf.string),
        'image': tf.TensorSpec(shape=(28, 28, 1), dtype=tf.uint8),
        'label': tf.TensorSpec(shape=(), dtype=tf.int64),
    }
    list(ds.take(3))  # Iteration should work
def load_shard(self, file_instruction, shuffle_files=False, seed=None):
    """Returns a dataset for a single shard of the TFDS TFRecord files."""
    # pytype:disable=attribute-error
    ds = self.builder._tfrecords_reader.read_files(  # pylint:disable=protected-access
        [file_instruction],
        read_config=tfds.ReadConfig(shuffle_seed=seed),
        shuffle_files=shuffle_files)
    # pytype:enable=attribute-error
    return ds
def _load_tfds(self, *, split, shuffle_seed):
    tfds_name = {'church': 'lsun/church_outdoor',
                 'bedroom': 'lsun/bedroom'}[self._subset]
    return tfds.load(
        tfds_name,
        split={'train': 'train', 'eval': 'validation'}[split],
        shuffle_files=shuffle_seed is not None,
        read_config=None if shuffle_seed is None else tfds.ReadConfig(
            shuffle_seed=shuffle_seed),
        decoders={'image': tfds.decode.SkipDecoding()})
def load_shard(self, file_instruction):
    """Returns a dataset for a single shard of the TFDS TFRecord files."""
    ds = self.builder._tfrecords_reader.read_files(  # pylint:disable=protected-access
        tfds.core.tfrecords_reader.FileInstructions(
            file_instructions=[file_instruction],
            num_examples_per_shard=None,
        ),
        read_config=tfds.ReadConfig(),
        shuffle_files=False)
    return ds