def _build_single_dataset(self, split, shuffle_files, batch_size, decoders,
                          as_supervised, in_memory):
  """Builds the `tf.data` pipeline for one split.

  Args:
    split: Split to read. A string is converted with `splits_lib.Split`.
    shuffle_files: Forwarded to `_as_dataset`. The pipeline's
      `experimental_deterministic` option is set to `not shuffle_files`.
    batch_size: Batch size handed to `padded_batch`. `-1` requests the whole
      split as one element; a falsy value skips batching.
    decoders: Forwarded to `_as_dataset`.
    as_supervised: If true, every feature dict is mapped to the
      `(input, target)` pair named by `self.info.supervised_keys`.
    in_memory: If true (and `batch_size != -1`), the split is materialized
      through `dataset_utils.as_numpy` and re-wrapped with
      `from_tensor_slices`. `None` is treated as `False`.

  Returns:
    The resulting dataset, or the output of
    `tf.data.experimental.get_single_element` when `batch_size == -1`.

  Raises:
    ValueError: If `as_supervised` is set but `self.info.supervised_keys`
      is empty.
  """
  if isinstance(split, six.string_types):
    split = splits_lib.Split(split)

  # `-1` is the sentinel for "give me everything as a single batch".
  load_everything = batch_size == -1
  if load_everything:
    batch_size = self.info.splits.total_num_examples or sys.maxsize

  shapes_fully_known = dataset_utils.features_shape_is_fully_defined(
      self.info.features)

  # TODO(tfds): Consider defaulting in_memory to True for small datasets
  # with fully-defined shapes. That needs the real on-disk data size:
  # `size_in_bytes` is the download size, which is misleading, particularly
  # for datasets that combine manual_dir with downloads (wmt and
  # diabetic_retinopathy_detection). The candidate rule was
  # `size_in_bytes <= 1e9`, excluding names starting with "wmt" or
  # "diabetic", and requiring fully-defined shapes.
  if in_memory is None:
    in_memory = False

  if not in_memory or load_everything:
    # Regular streaming pipeline.
    pipeline = self._as_dataset(
        split=split, shuffle_files=shuffle_files, decoders=decoders)
  else:
    # TODO(tfds): Enable in_memory without padding features. May be doable
    # with a version of tf.data.Dataset.cache that can persist a cache
    # beyond iterator instances.
    if not shapes_fully_known:
      logging.warning(
          "Called in_memory=True on a dataset that does not "
          "have fully defined shapes. Note that features with "
          "variable length dimensions will be 0-padded to "
          "the maximum length across the dataset.")
    everything = self.info.splits.total_num_examples or sys.maxsize
    # Escape every device context so the data can be loaded with a local
    # Session.
    with tf.device(None):
      pipeline = self._as_dataset(
          split=split, shuffle_files=shuffle_files, decoders=decoders)
      # padded_batch copes with features whose shape is not fully known.
      padded_shapes = tf.compat.v1.data.get_output_shapes(pipeline)
      pipeline = pipeline.padded_batch(everything, padded_shapes)
      whole_split = next(dataset_utils.as_numpy(pipeline))
      pipeline = tf.data.Dataset.from_tensor_slices(whole_split)

  if batch_size:
    # padded_batch copes with features whose shape is not fully known.
    padded_shapes = tf.compat.v1.data.get_output_shapes(pipeline)
    pipeline = pipeline.padded_batch(batch_size, padded_shapes)

  if as_supervised:
    supervised_keys = self.info.supervised_keys
    if not supervised_keys:
      raise ValueError(
          "as_supervised=True but %s does not support a supervised "
          "(input, label) structure." % self.name)
    input_key, target_key = supervised_keys
    pipeline = pipeline.map(
        lambda features: (features[input_key], features[target_key]),
        num_parallel_calls=tf.data.experimental.AUTOTUNE)

  pipeline = pipeline.prefetch(tf.data.experimental.AUTOTUNE)

  # Shuffled reads may be served out of order; unshuffled reads stay
  # deterministic.
  pipeline_options = tf.data.Options()
  pipeline_options.experimental_deterministic = not shuffle_files
  pipeline = pipeline.with_options(pipeline_options)

  if not load_everything:
    return pipeline
  return tf.data.experimental.get_single_element(pipeline)
def _build_single_dataset(
    self,
    split,
    shuffle_files,
    batch_size,
    decoders,
    read_config,
    as_supervised,
    in_memory):
  """Builds the `tf.data` pipeline for a single split.

  Args:
    split: Split to read. A string is converted with `splits_lib.Split`.
    shuffle_files: Forwarded to `_as_dataset` and `_should_cache_ds`. When
      true, and `read_config` pins neither determinism nor a shuffle seed,
      the pipeline is marked non-deterministic.
    batch_size: Batch size handed to `padded_batch`. `-1` requests the whole
      split as one element; a falsy value skips batching.
    decoders: Forwarded to `_as_dataset`.
    read_config: Forwarded to `_as_dataset` and `_should_cache_ds`. Its
      `options.experimental_deterministic` and `shuffle_seed` attributes
      are consulted before determinism is relaxed.
    as_supervised: If true, every feature dict is mapped to the
      `(input, target)` pair named by `self.info.supervised_keys`.
    in_memory: Deprecated. If true (and `batch_size != -1`), the split is
      materialized through `dataset_utils.as_numpy` and re-wrapped with
      `from_tensor_slices`; a deprecation warning is logged.

  Returns:
    The resulting dataset, or the output of
    `tf.data.experimental.get_single_element` when `batch_size == -1`.

  Raises:
    ValueError: If `as_supervised` is set but `self.info.supervised_keys`
      is empty.
  """
  if isinstance(split, six.string_types):
    split = splits_lib.Split(split)

  # `-1` is the sentinel for "return the full split as a single batch".
  wants_full_dataset = batch_size == -1
  if wants_full_dataset:
    batch_size = self.info.splits.total_num_examples or sys.maxsize

  # Build base dataset
  if in_memory and not wants_full_dataset:
    # TODO(tfds): Remove once users have been migrated
    logging.warning(
        "`in_memory` is deprecated and will be removed in a future version. "
        "Please use `ds = ds.cache()` instead.")

    # TODO(tfds): Enable in_memory without padding features. May be able
    # to do by using a requested version of tf.data.Dataset.cache that can
    # persist a cache beyond iterator instances.
    dataset_shape_is_fully_defined = (
        dataset_utils.features_shape_is_fully_defined(self.info.features))
    if not dataset_shape_is_fully_defined:
      logging.warning("Called in_memory=True on a dataset that does not "
                      "have fully defined shapes. Note that features with "
                      "variable length dimensions will be 0-padded to "
                      "the maximum length across the dataset.")
    full_bs = self.info.splits.total_num_examples or sys.maxsize
    # If using in_memory, escape all device contexts so we can load the data
    # with a local Session.
    with tf.device(None):
      ds = self._as_dataset(
          split=split,
          shuffle_files=shuffle_files,
          decoders=decoders,
          read_config=read_config,
      )
      # Use padded_batch so that features with unknown shape are supported.
      ds = ds.padded_batch(
          full_bs, tf.compat.v1.data.get_output_shapes(ds))
      # Pull the single full-size batch and slice it back into examples.
      ds = tf.compat.v1.data.Dataset.from_tensor_slices(
          next(dataset_utils.as_numpy(ds)))
  else:
    ds = self._as_dataset(
        split=split,
        shuffle_files=shuffle_files,
        decoders=decoders,
        read_config=read_config,
    )
    # Auto-cache datasets which are small enough to fit in memory.
    if self._should_cache_ds(
        split=split,
        shuffle_files=shuffle_files,
        read_config=read_config,
    ):
      ds = ds.cache()

  if batch_size:
    # Use padded_batch so that features with unknown shape are supported.
    ds = ds.padded_batch(
        batch_size, tf.compat.v1.data.get_output_shapes(ds))

  if as_supervised:
    if not self.info.supervised_keys:
      raise ValueError(
          "as_supervised=True but %s does not support a supervised "
          "(input, label) structure." % self.name)
    input_f, target_f = self.info.supervised_keys
    ds = ds.map(lambda fs: (fs[input_f], fs[target_f]),
                num_parallel_calls=tf.data.experimental.AUTOTUNE)

  ds = ds.prefetch(tf.data.experimental.AUTOTUNE)

  # If shuffling is True and seeds are not set, allow the pipeline to be
  # non-deterministic.
  # This code should probably be moved inside tfreader, so that all the
  # tf.data.Options are centralized in a single place.
  if (shuffle_files and
      read_config.options.experimental_deterministic is None and
      read_config.shuffle_seed is None):
    options = tf.data.Options()
    options.experimental_deterministic = False
    ds = ds.with_options(options)
  # If shuffle is False, keep the default value (deterministic), which
  # allows the user to overwrite it.

  if wants_full_dataset:
    return tf.data.experimental.get_single_element(ds)
  return ds