Example #1
def _read_tfds(tfds_builder: tfds.core.DatasetBuilder,
               tfds_split: Text,
               tfds_skip_decoding_feature: Text,
               tfds_as_supervised: bool,
               input_context: Optional[tf.distribute.InputContext] = None,
               seed: Optional[Union[int, tf.Tensor]] = None,
               is_training: bool = False,
               cache: bool = False,
               cycle_length: Optional[int] = None,
               block_length: Optional[int] = None) -> tf.data.Dataset:
    """Reads a dataset from tfds."""
    # No-op if the dataset has already been downloaded and prepared.
    tfds_builder.download_and_prepare()
    decoders = {}
    if tfds_skip_decoding_feature:
        for skip_feature in tfds_skip_decoding_feature.split(','):
            decoders[skip_feature.strip()] = tfds.decode.SkipDecoding()
    if tfds_builder.info.splits:
        num_shards = len(
            tfds_builder.info.splits[tfds_split].file_instructions)
    else:
        # The tfds mock path often does not provide splits.
        num_shards = 1
    if input_context and num_shards < input_context.num_input_pipelines:
        # The number of files in the dataset split is smaller than the number of
        # input pipelines. We read the entire dataset first and then shard in the
        # host memory.
        read_config = tfds.ReadConfig(interleave_cycle_length=cycle_length,
                                      interleave_block_length=block_length,
                                      input_context=None,
                                      shuffle_seed=seed)
        dataset = tfds_builder.as_dataset(split=tfds_split,
                                          shuffle_files=is_training,
                                          as_supervised=tfds_as_supervised,
                                          decoders=decoders,
                                          read_config=read_config)
        dataset = dataset.shard(input_context.num_input_pipelines,
                                input_context.input_pipeline_id)
    else:
        read_config = tfds.ReadConfig(interleave_cycle_length=cycle_length,
                                      interleave_block_length=block_length,
                                      input_context=input_context,
                                      shuffle_seed=seed)
        dataset = tfds_builder.as_dataset(split=tfds_split,
                                          shuffle_files=is_training,
                                          as_supervised=tfds_as_supervised,
                                          decoders=decoders,
                                          read_config=read_config)

    if is_training and not cache:
        dataset = dataset.repeat()
    return dataset
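A minimal usage sketch for _read_tfds as defined above. The dataset name ('mnist'), the skipped feature and the interleave settings are illustrative assumptions, not part of the original snippet:

import tensorflow_datasets as tfds

# Hypothetical call: read the MNIST train split, skip decoding of the
# 'image' feature, and repeat indefinitely (is_training=True, cache=False).
builder = tfds.builder('mnist')
dataset = _read_tfds(
    tfds_builder=builder,
    tfds_split='train',
    tfds_skip_decoding_feature='image',
    tfds_as_supervised=False,
    is_training=True,
    cycle_length=8,
    block_length=1)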
Example #2
def create_split(dataset_builder: tfds.core.DatasetBuilder,
                 batch_size: int,
                 train: bool,
                 dtype: tf.DType = tf.float32,
                 image_size: int = IMAGE_SIZE,
                 cache: bool = False):
    """Creates a split from the ImageNet dataset using TensorFlow Datasets.

  Args:
    dataset_builder: TFDS dataset builder for ImageNet.
    batch_size: the batch size returned by the data pipeline.
    train: Whether to load the train or evaluation split.
    dtype: data type of the image (default: float32).
    image_size: The target size of the images (default: 224).
    cache: Whether to cache the dataset (default: False).
  Returns:
    A `tf.data.Dataset`.
  """
    if train:
        train_size = dataset_builder.info.splits['train'].num_examples
        split_size = train_size // jax.host_count()
        start = jax.host_id() * split_size
        split = 'train[{}:{}]'.format(start, start + split_size)
    else:
        validation_size = dataset_builder.info.splits[
            'validation'].num_examples
        split_size = validation_size // jax.host_count()
        start = jax.host_id() * split_size
        split = 'validation[{}:{}]'.format(start, start + split_size)

    def _decode_example(example):
        if train:
            image = preprocess_for_train(example['image'], dtype, image_size)
        else:
            image = preprocess_for_eval(example['image'], dtype, image_size)
        return {'image': image, 'label': example['label']}

    ds = dataset_builder.as_dataset(
        split=split, decoders={'image': tfds.decode.SkipDecoding()})
    # Apply threading options through with_options so they are actually
    # attached to the returned dataset.
    options = tf.data.Options()
    options.experimental_threading.private_threadpool_size = 48
    options.experimental_threading.max_intra_op_parallelism = 1
    ds = ds.with_options(options)

    if cache:
        ds = ds.cache()

    if train:
        ds = ds.repeat()
        ds = ds.shuffle(16 * batch_size, seed=0)

    ds = ds.map(_decode_example,
                num_parallel_calls=tf.data.experimental.AUTOTUNE)

    ds = ds.batch(batch_size, drop_remainder=True)

    if not train:
        ds = ds.repeat()

    ds = ds.prefetch(10)

    return ds
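A hedged usage sketch for this create_split. It assumes an ImageNet TFDS builder is available locally and that IMAGE_SIZE, preprocess_for_train and preprocess_for_eval are defined in the surrounding module; the batch size and dtype are illustrative:

import tensorflow as tf
import tensorflow_datasets as tfds

# Illustrative only: build the per-host training split with a batch size of 128.
imagenet_builder = tfds.builder('imagenet2012')
train_ds = create_split(imagenet_builder, batch_size=128, train=True,
                        dtype=tf.bfloat16, cache=False)
for batch in train_ds.take(1):
    print(batch['image'].shape, batch['label'].shape)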
Example #3
def _read_tfds(tfds_builder: tfds.core.DatasetBuilder,
               tfds_split: Text,
               tfds_skip_decoding_feature: Text,
               tfds_as_supervised: bool,
               input_context: Optional[tf.distribute.InputContext] = None,
               seed: Optional[Union[int, tf.Tensor]] = None,
               is_training: bool = False,
               cache: bool = False,
               cycle_length: Optional[int] = None,
               block_length: Optional[int] = None) -> tf.data.Dataset:
    """Reads a dataset from tfds."""
    # No-op if the dataset has already been downloaded and prepared.
    tfds_builder.download_and_prepare()

    read_config = tfds.ReadConfig(interleave_cycle_length=cycle_length,
                                  interleave_block_length=block_length,
                                  input_context=input_context,
                                  shuffle_seed=seed)
    decoders = {}
    if tfds_skip_decoding_feature:
        for skip_feature in tfds_skip_decoding_feature.split(','):
            decoders[skip_feature.strip()] = tfds.decode.SkipDecoding()
    dataset = tfds_builder.as_dataset(split=tfds_split,
                                      shuffle_files=is_training,
                                      as_supervised=tfds_as_supervised,
                                      decoders=decoders,
                                      read_config=read_config)

    if is_training and not cache:
        dataset = dataset.repeat()
    return dataset
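In practice the input_context argument of this simpler _read_tfds is supplied by a tf.distribute strategy rather than constructed by hand. A minimal sketch, assuming a 'mnist' builder and TF 2.4+ where distribute_datasets_from_function is available:

import tensorflow as tf
import tensorflow_datasets as tfds

strategy = tf.distribute.MirroredStrategy()
builder = tfds.builder('mnist')

def dataset_fn(input_context: tf.distribute.InputContext) -> tf.data.Dataset:
    # Each input pipeline receives its own shard via the ReadConfig inside
    # _read_tfds.
    return _read_tfds(builder,
                      tfds_split='train',
                      tfds_skip_decoding_feature='',
                      tfds_as_supervised=True,
                      input_context=input_context,
                      is_training=True)

dist_dataset = strategy.distribute_datasets_from_function(dataset_fn)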
Example #4
def get_raw_dataset(dataset_builder: tfds.core.DatasetBuilder,
                    split: str,
                    *,
                    reverse_translation: bool = False) -> tf.data.Dataset:
    """Loads a raw WMT dataset and normalizes feature keys.

  Args:
    dataset_builder: TFDS dataset builder that can build `split`.
    split: Split to use. This must be the full split. We shard the split across
      multiple hosts and currently don't support sharding subsplits.
    reverse_translation: whether to reverse the translation direction, e.g. for
      'de-en' this translates from English to German.

  Returns:
    Dataset with source and target language features mapped to 'inputs' and
    'targets'.
  """
    num_examples = dataset_builder.info.splits[split].num_examples
    per_host_split = deterministic_data.get_read_instruction_for_host(
        split, num_examples, drop_remainder=False)
    ds = dataset_builder.as_dataset(split=per_host_split, shuffle_files=False)
    ds = ds.map(NormalizeFeatureNamesOp(
        dataset_builder.info, reverse_translation=reverse_translation),
                num_parallel_calls=AUTOTUNE)
    return ds
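An illustrative call, assuming the WMT14 de-en TFDS builder (any translation dataset whose feature keys NormalizeFeatureNamesOp understands would work):

import tensorflow_datasets as tfds

# Hypothetical: load the per-host shard of the WMT14 de-en training split
# with the translation direction reversed (en -> de).
wmt_builder = tfds.builder('wmt14_translate/de-en')
wmt_builder.download_and_prepare()
train_ds = get_raw_dataset(wmt_builder, 'train', reverse_translation=True)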
Example #5
def get_raw_dataset(dataset_builder: tfds.core.DatasetBuilder,
                    split: str) -> tf.data.Dataset:
    """Loads a raw text dataset and normalizes feature keys.

  Args:
    dataset_builder: TFDS dataset builder that can build `split`.
    split: Split to use. This must be the full split. We shard the split across
      multiple hosts and currently don't support sharding subsplits.

  Returns:
    Dataset with source and target language features mapped to 'inputs' and
    'targets'.
  """
    per_host_split = deterministic_data.get_read_instruction_for_host(
        split, dataset_info=dataset_builder.info, drop_remainder=False)
    ds = dataset_builder.as_dataset(split=per_host_split, shuffle_files=False)
    ds = ds.map(NormalizeFeatureNamesOp(dataset_builder.info),
                num_parallel_calls=AUTOTUNE)
    return ds
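NormalizeFeatureNamesOp is referenced in both get_raw_dataset variants but not shown here. A minimal sketch of what such an op could look like, assuming the builder exposes a supervised (source, target) key pair as WMT-style translation builders do:

import tensorflow_datasets as tfds

class NormalizeFeatureNamesOp:
    """Maps the source/target language features to 'inputs' and 'targets'.

    Sketch only: assumes the supervised keys can be read from the builder
    info, which holds for WMT-style translation datasets.
    """

    def __init__(self, ds_info: tfds.core.DatasetInfo,
                 reverse_translation: bool = False):
        self.source_key, self.target_key = ds_info.supervised_keys
        if reverse_translation:
            self.source_key, self.target_key = self.target_key, self.source_key

    def __call__(self, features):
        return {
            'inputs': features[self.source_key],
            'targets': features[self.target_key],
        }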
def create_split(
    dataset_builder: tfds.core.DatasetBuilder,
    batch_size: int,
    train: bool = True,
    half_precision: bool = False,
    image_size: int = IMAGE_SIZE,
    mean: Optional[Tuple[float]] = None,
    std: Optional[Tuple[float]] = None,
    interpolation: str = 'bicubic',
    augment_name: Optional[str] = None,
    randaug_num_layers: Optional[int] = None,
    randaug_magnitude: Optional[int] = None,
    cache: bool = False,
    no_repeat: bool = False,
):
    """Creates a split from the ImageNet dataset using TensorFlow Datasets.

    Args:
      dataset_builder: TFDS dataset builder for ImageNet.
      batch_size: the batch size returned by the data pipeline.
      train: Whether to load the train or evaluation split.
      half_precision: convert the image datatype to half precision (bfloat16 on
        TPU, float16 otherwise).
      image_size: The target size of the images (default: 224).
      mean: image dataset mean.
      std: image dataset std-dev.
      interpolation: interpolation method to use for image resize
        (default: 'bicubic').
      augment_name: name of the augmentation policy passed to
        `preprocess_for_train` (e.g. RandAugment).
      randaug_num_layers: number of RandAugment layers to apply.
      randaug_magnitude: magnitude of the RandAugment transformations.
      cache: Whether to cache the dataset (default: False).
      no_repeat: disable repeat iteration for evaluation.
    Returns:
      A `tf.data.Dataset`.
    """
    mean = mean or MEAN_RGB
    std = std or STDDEV_RGB
    interpolation = (tf.image.ResizeMethod.BICUBIC if interpolation == 'bicubic'
                     else tf.image.ResizeMethod.BILINEAR)
    platform = jax.local_devices()[0].platform
    if half_precision:
        if platform == 'tpu':
            input_dtype = tf.bfloat16
        else:
            input_dtype = tf.float16
    else:
        input_dtype = tf.float32

    if train:
        data_size = dataset_builder.info.splits['train'].num_examples
        split = 'train'
    else:
        data_size = dataset_builder.info.splits['validation'].num_examples
        split = 'validation'
    split_size = data_size // jax.host_count()
    start = jax.host_id() * split_size
    split = split + '[{}:{}]'.format(start, start + split_size)

    def _decode_example(example):
        if train:
            image = preprocess_for_train(example['image'],
                                         input_dtype,
                                         image_size,
                                         mean,
                                         std,
                                         interpolation,
                                         augment_name=augment_name,
                                         randaug_num_layers=randaug_num_layers,
                                         randaug_magnitude=randaug_magnitude)
        else:
            image = preprocess_for_eval(example['image'], input_dtype,
                                        image_size, mean, std, interpolation)
        return {'image': image, 'label': example['label']}

    ds = dataset_builder.as_dataset(
        split=split, decoders={'image': tfds.decode.SkipDecoding()})
    # Apply threading options through with_options so they are actually
    # attached to the returned dataset.
    options = tf.data.Options()
    options.experimental_threading.private_threadpool_size = 16
    options.experimental_threading.max_intra_op_parallelism = 1
    ds = ds.with_options(options)

    if cache:
        ds = ds.cache()

    if train:
        ds = ds.repeat()
        ds = ds.shuffle(16 * batch_size, seed=0)

    ds = ds.map(_decode_example,
                num_parallel_calls=tf.data.experimental.AUTOTUNE)
    ds = ds.batch(batch_size, drop_remainder=True)

    if not train and not no_repeat:
        ds = ds.repeat()

    ds = ds.prefetch(10)

    return ds
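A hedged usage sketch for this augmented create_split variant. It assumes an ImageNet builder plus MEAN_RGB, STDDEV_RGB and the preprocessing helpers from the surrounding module; the RandAugment settings and the 'randaugment' policy name are illustrative and only meaningful if preprocess_for_train understands them:

import tensorflow_datasets as tfds

imagenet_builder = tfds.builder('imagenet2012')
train_ds = create_split(
    imagenet_builder,
    batch_size=256,
    train=True,
    half_precision=True,
    interpolation='bicubic',
    augment_name='randaugment',   # assumed policy name, passed through to preprocess_for_train
    randaug_num_layers=2,
    randaug_magnitude=10)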