def _read_tfds(tfds_builder: tfds.core.DatasetBuilder,
               tfds_split: Text,
               tfds_skip_decoding_feature: Text,
               tfds_as_supervised: bool,
               input_context: Optional[tf.distribute.InputContext] = None,
               seed: Optional[Union[int, tf.Tensor]] = None,
               is_training: bool = False,
               cache: bool = False,
               cycle_length: Optional[int] = None,
               block_length: Optional[int] = None) -> tf.data.Dataset:
  """Reads a dataset from tfds."""
  # No-op if the data is already downloaded and prepared.
  tfds_builder.download_and_prepare()

  decoders = {}
  if tfds_skip_decoding_feature:
    for skip_feature in tfds_skip_decoding_feature.split(','):
      decoders[skip_feature.strip()] = tfds.decode.SkipDecoding()

  if tfds_builder.info.splits:
    num_shards = len(tfds_builder.info.splits[tfds_split].file_instructions)
  else:
    # The tfds mock path often does not provide splits.
    num_shards = 1

  if input_context and num_shards < input_context.num_input_pipelines:
    # The number of files in the dataset split is smaller than the number of
    # input pipelines, so we read the entire dataset on every pipeline and
    # then shard it in host memory.
    read_config = tfds.ReadConfig(interleave_cycle_length=cycle_length,
                                  interleave_block_length=block_length,
                                  input_context=None,
                                  shuffle_seed=seed)
    dataset = tfds_builder.as_dataset(split=tfds_split,
                                      shuffle_files=is_training,
                                      as_supervised=tfds_as_supervised,
                                      decoders=decoders,
                                      read_config=read_config)
    dataset = dataset.shard(input_context.num_input_pipelines,
                            input_context.input_pipeline_id)
  else:
    read_config = tfds.ReadConfig(interleave_cycle_length=cycle_length,
                                  interleave_block_length=block_length,
                                  input_context=input_context,
                                  shuffle_seed=seed)
    dataset = tfds_builder.as_dataset(split=tfds_split,
                                      shuffle_files=is_training,
                                      as_supervised=tfds_as_supervised,
                                      decoders=decoders,
                                      read_config=read_config)

  if is_training and not cache:
    dataset = dataset.repeat()
  return dataset
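# A minimal usage sketch (not from the original source): reading a sharded
# training split under tf.distribute. The 'mnist' builder name and the
# two-pipeline InputContext below are illustrative assumptions.
def _example_read_tfds_usage():
  builder = tfds.builder('mnist')
  input_context = tf.distribute.InputContext(
      num_input_pipelines=2, input_pipeline_id=0)
  return _read_tfds(builder,
                    tfds_split='train',
                    tfds_skip_decoding_feature='',
                    tfds_as_supervised=True,
                    input_context=input_context,
                    seed=42,
                    is_training=True,
                    cycle_length=8,
                    block_length=1)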
def create_split(dataset_builder: tfds.core.DatasetBuilder,
                 batch_size: int,
                 train: bool,
                 dtype: tf.DType = tf.float32,
                 image_size: int = IMAGE_SIZE,
                 cache: bool = False):
  """Creates a split from the ImageNet dataset using TensorFlow Datasets.

  Args:
    dataset_builder: TFDS dataset builder for ImageNet.
    batch_size: the batch size returned by the data pipeline.
    train: Whether to load the train or evaluation split.
    dtype: data type of the image (default: float32).
    image_size: The target size of the images (default: 224).
    cache: Whether to cache the dataset (default: False).

  Returns:
    A `tf.data.Dataset`.
  """
  if train:
    train_size = dataset_builder.info.splits['train'].num_examples
    split_size = train_size // jax.host_count()
    start = jax.host_id() * split_size
    split = 'train[{}:{}]'.format(start, start + split_size)
  else:
    validation_size = dataset_builder.info.splits['validation'].num_examples
    split_size = validation_size // jax.host_count()
    start = jax.host_id() * split_size
    split = 'validation[{}:{}]'.format(start, start + split_size)

  def _decode_example(example):
    if train:
      image = preprocess_for_train(example['image'], dtype, image_size)
    else:
      image = preprocess_for_eval(example['image'], dtype, image_size)
    return {'image': image, 'label': example['label']}

  ds = dataset_builder.as_dataset(
      split=split, decoders={'image': tfds.decode.SkipDecoding()})
  # Mutating the result of `Dataset.options()` has no effect in recent TF
  # releases; apply the threading options via `with_options` instead.
  options = tf.data.Options()
  options.experimental_threading.private_threadpool_size = 48
  options.experimental_threading.max_intra_op_parallelism = 1
  ds = ds.with_options(options)

  if cache:
    ds = ds.cache()

  if train:
    ds = ds.repeat()
    ds = ds.shuffle(16 * batch_size, seed=0)

  ds = ds.map(_decode_example,
              num_parallel_calls=tf.data.experimental.AUTOTUNE)
  ds = ds.batch(batch_size, drop_remainder=True)

  if not train:
    ds = ds.repeat()

  ds = ds.prefetch(10)
  return ds
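# A minimal usage sketch (illustrative, not from the original source):
# building per-host train and eval pipelines. Assumes the imagenet2012 data
# has already been downloaded and prepared.
def _example_create_split_usage():
  dataset_builder = tfds.builder('imagenet2012')
  train_ds = create_split(dataset_builder, batch_size=128, train=True,
                          dtype=tf.bfloat16, image_size=224)
  eval_ds = create_split(dataset_builder, batch_size=128, train=False)
  return train_ds, eval_ds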
def _read_tfds(tfds_builder: tfds.core.DatasetBuilder,
               tfds_split: Text,
               tfds_skip_decoding_feature: Text,
               tfds_as_supervised: bool,
               input_context: Optional[tf.distribute.InputContext] = None,
               seed: Optional[Union[int, tf.Tensor]] = None,
               is_training: bool = False,
               cache: bool = False,
               cycle_length: Optional[int] = None,
               block_length: Optional[int] = None) -> tf.data.Dataset:
  """Reads a dataset from tfds."""
  # No-op if the data is already downloaded and prepared.
  tfds_builder.download_and_prepare()

  read_config = tfds.ReadConfig(interleave_cycle_length=cycle_length,
                                interleave_block_length=block_length,
                                input_context=input_context,
                                shuffle_seed=seed)
  decoders = {}
  if tfds_skip_decoding_feature:
    for skip_feature in tfds_skip_decoding_feature.split(','):
      decoders[skip_feature.strip()] = tfds.decode.SkipDecoding()
  dataset = tfds_builder.as_dataset(split=tfds_split,
                                    shuffle_files=is_training,
                                    as_supervised=tfds_as_supervised,
                                    decoders=decoders,
                                    read_config=read_config)

  if is_training and not cache:
    dataset = dataset.repeat()
  return dataset
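# Illustrative sketch: `tfds_skip_decoding_feature` is a comma-separated list
# of feature names whose decoding is skipped, e.g. to keep raw image bytes
# for a custom decoder. The 'imagenet2012' builder name is an assumption.
def _example_skip_decoding_usage():
  builder = tfds.builder('imagenet2012')
  return _read_tfds(builder,
                    tfds_split='validation',
                    tfds_skip_decoding_feature='image',
                    tfds_as_supervised=False,
                    is_training=False)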
def get_raw_dataset(dataset_builder: tfds.core.DatasetBuilder,
                    split: str,
                    *,
                    reverse_translation: bool = False) -> tf.data.Dataset:
  """Loads a raw WMT dataset and normalizes feature keys.

  Args:
    dataset_builder: TFDS dataset builder that can build `split`.
    split: Split to use. This must be the full split. We shard the split
      across multiple hosts and currently don't support sharding subsplits.
    reverse_translation: Whether to reverse the translation direction,
      e.g. for 'de-en' this translates from English to German.

  Returns:
    Dataset with source and target language features mapped to 'inputs' and
    'targets'.
  """
  num_examples = dataset_builder.info.splits[split].num_examples
  per_host_split = deterministic_data.get_read_instruction_for_host(
      split, num_examples, drop_remainder=False)
  ds = dataset_builder.as_dataset(split=per_host_split, shuffle_files=False)
  ds = ds.map(
      NormalizeFeatureNamesOp(
          dataset_builder.info, reverse_translation=reverse_translation),
      num_parallel_calls=AUTOTUNE)
  return ds
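# Illustrative sketch (assumes `deterministic_data` above is
# `clu.deterministic_data`): loading the full training split of a WMT
# translation dataset in the reverse (en->de) direction.
def _example_get_raw_dataset_usage():
  builder = tfds.builder('wmt17_translate/de-en')
  return get_raw_dataset(builder, 'train', reverse_translation=True)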
def get_raw_dataset(dataset_builder: tfds.core.DatasetBuilder,
                    split: str) -> tf.data.Dataset:
  """Loads a raw text dataset and normalizes feature keys.

  Args:
    dataset_builder: TFDS dataset builder that can build `split`.
    split: Split to use. This must be the full split. We shard the split
      across multiple hosts and currently don't support sharding subsplits.

  Returns:
    Dataset with source and target language features mapped to 'inputs' and
    'targets'.
  """
  per_host_split = deterministic_data.get_read_instruction_for_host(
      split, dataset_info=dataset_builder.info, drop_remainder=False)
  ds = dataset_builder.as_dataset(split=per_host_split, shuffle_files=False)
  ds = ds.map(NormalizeFeatureNamesOp(dataset_builder.info),
              num_parallel_calls=AUTOTUNE)
  return ds
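# `NormalizeFeatureNamesOp` is defined elsewhere in the source. A minimal
# sketch of what it might look like, assuming the dataset info exposes
# `supervised_keys` such as ('de', 'en') for a translation dataset:
class _NormalizeFeatureNamesOpSketch:
  """Maps the source/target language features to 'inputs'/'targets'."""

  def __init__(self, ds_info: tfds.core.DatasetInfo):
    self.input_lang, self.target_lang = ds_info.supervised_keys

  def __call__(self, features):
    features['inputs'] = features.pop(self.input_lang)
    features['targets'] = features.pop(self.target_lang)
    return features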
def create_split(
    dataset_builder: tfds.core.DatasetBuilder,
    batch_size: int,
    train: bool = True,
    half_precision: bool = False,
    image_size: int = IMAGE_SIZE,
    mean: Optional[Tuple[float]] = None,
    std: Optional[Tuple[float]] = None,
    interpolation: str = 'bicubic',
    augment_name: Optional[str] = None,
    randaug_num_layers: Optional[int] = None,
    randaug_magnitude: Optional[int] = None,
    cache: bool = False,
    no_repeat: bool = False,
):
  """Creates a split from the ImageNet dataset using TensorFlow Datasets.

  Args:
    dataset_builder: TFDS dataset builder for ImageNet.
    batch_size: the batch size returned by the data pipeline.
    train: Whether to load the train or evaluation split.
    half_precision: Whether to convert the image datatype to half-precision.
    image_size: The target size of the images (default: 224).
    mean: Image dataset mean.
    std: Image dataset standard deviation.
    interpolation: Interpolation method to use for image resize
      (default: 'bicubic').
    augment_name: Name of the augmentation policy to apply during training.
    randaug_num_layers: Number of RandAugment layers to apply.
    randaug_magnitude: RandAugment magnitude.
    cache: Whether to cache the dataset (default: False).
    no_repeat: Whether to disable repeating the dataset during evaluation.

  Returns:
    A `tf.data.Dataset`.
  """
  mean = mean or MEAN_RGB
  std = std or STDDEV_RGB
  interpolation = (tf.image.ResizeMethod.BICUBIC if interpolation == 'bicubic'
                   else tf.image.ResizeMethod.BILINEAR)

  platform = jax.local_devices()[0].platform
  if half_precision:
    if platform == 'tpu':
      input_dtype = tf.bfloat16
    else:
      input_dtype = tf.float16
  else:
    input_dtype = tf.float32

  if train:
    data_size = dataset_builder.info.splits['train'].num_examples
    split = 'train'
  else:
    data_size = dataset_builder.info.splits['validation'].num_examples
    split = 'validation'
  split_size = data_size // jax.host_count()
  start = jax.host_id() * split_size
  split = split + '[{}:{}]'.format(start, start + split_size)

  def _decode_example(example):
    if train:
      image = preprocess_for_train(example['image'], input_dtype, image_size,
                                   mean, std, interpolation,
                                   augment_name=augment_name,
                                   randaug_num_layers=randaug_num_layers,
                                   randaug_magnitude=randaug_magnitude)
    else:
      image = preprocess_for_eval(example['image'], input_dtype, image_size,
                                  mean, std, interpolation)
    return {'image': image, 'label': example['label']}

  ds = dataset_builder.as_dataset(
      split=split, decoders={'image': tfds.decode.SkipDecoding()})
  # Mutating the result of `Dataset.options()` has no effect in recent TF
  # releases; apply the threading options via `with_options` instead.
  options = tf.data.Options()
  options.experimental_threading.private_threadpool_size = 16
  options.experimental_threading.max_intra_op_parallelism = 1
  ds = ds.with_options(options)

  if cache:
    ds = ds.cache()

  if train:
    ds = ds.repeat()
    ds = ds.shuffle(16 * batch_size, seed=0)

  ds = ds.map(_decode_example,
              num_parallel_calls=tf.data.experimental.AUTOTUNE)
  ds = ds.batch(batch_size, drop_remainder=True)

  if not train and not no_repeat:
    ds = ds.repeat()

  ds = ds.prefetch(10)
  return ds
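# Illustrative sketch: a half-precision training split with RandAugment
# enabled. The 'randaugment' policy name is an assumption; the augmentation
# arguments are only forwarded to `preprocess_for_train`, defined elsewhere.
def _example_augmented_split_usage():
  dataset_builder = tfds.builder('imagenet2012')
  return create_split(dataset_builder,
                      batch_size=128,
                      train=True,
                      half_precision=True,
                      augment_name='randaugment',
                      randaug_num_layers=2,
                      randaug_magnitude=9)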